From a82ecb836247fb7984959cf1ceeea94f8930509d Mon Sep 17 00:00:00 2001 From: DamianCzajkowski <43958031+DamianCzajkowski@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:22:02 +0200 Subject: [PATCH 01/11] Add custom operation generation (#296) Add custom operation generation --- ariadne_codegen/client_generators/client.py | 203 +++++++- .../client_generators/constants.py | 20 + .../client_generators/custom_fields.py | 328 +++++++++++++ .../client_generators/custom_fields_typing.py | 112 +++++ .../client_generators/custom_operation.py | 219 +++++++++ .../dependencies/base_operation.py | 92 ++++ ariadne_codegen/client_generators/package.py | 105 +++- ariadne_codegen/client_generators/utils.py | 56 +++ ariadne_codegen/codegen.py | 59 ++- ariadne_codegen/main.py | 9 +- ariadne_codegen/settings.py | 3 +- .../expected_client/__init__.py | 56 +++ .../expected_client/async_base_client.py | 370 ++++++++++++++ .../expected_client/base_model.py | 27 ++ .../expected_client/base_operation.py | 92 ++++ .../expected_client/client.py | 52 ++ .../expected_client/custom_fields.py | 292 ++++++++++++ .../expected_client/custom_mutations.py | 9 + .../expected_client/custom_queries.py | 43 ++ .../expected_client/custom_typing_fields.py | 63 +++ .../expected_client/enums.py | 9 + .../expected_client/exceptions.py | 83 ++++ .../expected_client/input_types.py | 0 .../custom_query_builder/pyproject.toml | 5 + .../custom_query_builder/schema.graphql | 249 ++++++++++ .../main/custom_operation_builder/__init__.py | 0 .../graphql_client/__init__.py | 41 ++ .../graphql_client/async_base_client.py | 370 ++++++++++++++ .../graphql_client/base_model.py | 27 ++ .../graphql_client/base_operation.py | 92 ++++ .../graphql_client/client.py | 52 ++ .../graphql_client/custom_fields.py | 100 ++++ .../graphql_client/custom_mutations.py | 63 +++ .../graphql_client/custom_queries.py | 64 +++ .../graphql_client/custom_typing_fields.py | 27 ++ .../graphql_client/enums.py | 7 + .../graphql_client/exceptions.py | 83 ++++ .../graphql_client/input_types.py | 22 + .../test_operation_build.py | 450 ++++++++++++++++++ tests/main/test_main.py | 8 + 40 files changed, 3948 insertions(+), 14 deletions(-) create mode 100644 ariadne_codegen/client_generators/custom_fields.py create mode 100644 ariadne_codegen/client_generators/custom_fields_typing.py create mode 100644 ariadne_codegen/client_generators/custom_operation.py create mode 100644 ariadne_codegen/client_generators/dependencies/base_operation.py create mode 100644 ariadne_codegen/client_generators/utils.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/__init__.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/async_base_client.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/base_model.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/base_operation.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/client.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/custom_fields.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/custom_mutations.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/custom_queries.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/enums.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/exceptions.py create mode 100644 tests/main/clients/custom_query_builder/expected_client/input_types.py create mode 100644 tests/main/clients/custom_query_builder/pyproject.toml create mode 100644 tests/main/clients/custom_query_builder/schema.graphql create mode 100644 tests/main/custom_operation_builder/__init__.py create mode 100644 tests/main/custom_operation_builder/graphql_client/__init__.py create mode 100644 tests/main/custom_operation_builder/graphql_client/async_base_client.py create mode 100644 tests/main/custom_operation_builder/graphql_client/base_model.py create mode 100644 tests/main/custom_operation_builder/graphql_client/base_operation.py create mode 100644 tests/main/custom_operation_builder/graphql_client/client.py create mode 100644 tests/main/custom_operation_builder/graphql_client/custom_fields.py create mode 100644 tests/main/custom_operation_builder/graphql_client/custom_mutations.py create mode 100644 tests/main/custom_operation_builder/graphql_client/custom_queries.py create mode 100644 tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py create mode 100644 tests/main/custom_operation_builder/graphql_client/enums.py create mode 100644 tests/main/custom_operation_builder/graphql_client/exceptions.py create mode 100644 tests/main/custom_operation_builder/graphql_client/input_types.py create mode 100644 tests/main/custom_operation_builder/test_operation_build.py diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 956bb9be..5b245463 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -14,10 +14,13 @@ generate_await, generate_call, generate_class_def, + generate_comp, generate_constant, generate_expr, generate_import_from, generate_keyword, + generate_list, + generate_list_comp, generate_method_definition, generate_module, generate_name, @@ -32,11 +35,20 @@ from .constants import ( ANY, ASYNC_ITERATOR, + BASE_GRAPHQL_FIELD_CLASS_NAME, + BASE_OPERATION_FILE_PATH, DICT, + DOCUMENT_NODE, + GRAPHQL_MODULE, KWARGS_NAMES, LIST, MODEL_VALIDATE_METHOD, + NAME_NODE, + OPERATION_DEFINITION_NODE, + OPERATION_TYPE, OPTIONAL, + PRINT_AST, + SELECTION_SET_NODE, TYPING_MODULE, UNION, UNSET_IMPORT, @@ -66,10 +78,18 @@ def __init__( self.custom_scalars = custom_scalars if custom_scalars else {} self.arguments_generator = arguments_generator - self._imports: List[ast.ImportFrom] = [] + self._imports: List[Union[ast.ImportFrom, ast.Import]] = [] self._add_import( generate_import_from( - [OPTIONAL, LIST, DICT, ANY, UNION, ASYNC_ITERATOR], TYPING_MODULE + [ + OPTIONAL, + LIST, + DICT, + ANY, + UNION, + ASYNC_ITERATOR, + ], + TYPING_MODULE, ) ) self._add_import(base_client_import) @@ -187,6 +207,185 @@ def add_method( generate_import_from(names=[return_type], from_=return_type_module, level=1) ) + def add_execute_custom_operation_method(self): + self._add_import( + generate_import_from( + [ + DOCUMENT_NODE, + OPERATION_DEFINITION_NODE, + NAME_NODE, + SELECTION_SET_NODE, + PRINT_AST, + ], + GRAPHQL_MODULE, + ) + ) + self._add_import( + generate_import_from( + [BASE_GRAPHQL_FIELD_CLASS_NAME], BASE_OPERATION_FILE_PATH.stem, level=1 + ) + ) + execute_await = generate_await( + value=generate_call( + func=generate_attribute(value=generate_name("self"), attr="execute"), + args=[ + generate_call( + func=generate_name("print_ast"), + args=[generate_name("operation_ast")], + ) + ], + keywords=[ + generate_keyword( + arg="operation_name", value=generate_name("operation_name") + ) + ], + ) + ) + + operation_definition_node = generate_call( + func=generate_name("OperationDefinitionNode"), + keywords=[ + generate_keyword( + arg="operation", value=generate_name("operation_type") + ), + generate_keyword( + arg="name", + value=generate_call( + func=generate_name("NameNode"), + keywords=[ + generate_keyword( + arg="value", value=generate_name("operation_name") + ) + ], + ), + ), + generate_keyword( + arg="selection_set", + value=generate_call( + func=generate_name("SelectionSetNode"), + keywords=[ + generate_keyword( + arg="selections", + value=generate_list_comp( + elt=generate_call( + func=generate_attribute( + value=generate_name("field"), + attr="to_ast", + ), + ), + generators=[ + generate_comp( + target="field", + iter_="fields", + ) + ], + ), + ) + ], + ), + ), + ], + ) + operation_ast = generate_call( + func=generate_name("DocumentNode"), + keywords=[ + generate_keyword( + arg="definitions", + value=generate_list(elements=[operation_definition_node]), + ) + ], + ) + body_return = generate_return( + value=generate_call( + func=generate_attribute(value=generate_name("self"), attr="get_data"), + args=[generate_name("response")], + ) + ) + async_def_node = generate_async_method_definition( + name="execute_custom_operation", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + "*fields", + annotation=generate_name("GraphQLField"), + ), + generate_arg( + "operation_type", + annotation=generate_name("OperationType"), + ), + generate_arg("operation_name", annotation=generate_name("str")), + ], + ), + body=[ + generate_assign( + targets=["operation_ast"], + value=operation_ast, + ), + generate_assign( + targets=["response"], + value=execute_await, + ), + body_return, + ], + return_type=generate_subscript( + generate_name(DICT), + generate_tuple([generate_name("str"), generate_name("Any")]), + ), + ) + self._class_def.body.append(async_def_node) + + def create_custom_operation_method(self, name, operation_type): + self._add_import( + generate_import_from( + [ + OPERATION_TYPE, + ], + GRAPHQL_MODULE, + ) + ) + body_return = generate_return( + value=generate_await( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), + attr="execute_custom_operation", + ), + args=[ + generate_name("*fields"), + ], + keywords=[ + generate_keyword( + arg="operation_type", + value=generate_attribute( + value=generate_name("OperationType"), + attr=operation_type, + ), + ), + generate_keyword( + arg="operation_name", value=generate_name("operation_name") + ), + ], + ) + ) + ) + async_def_query = generate_async_method_definition( + name=name, + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg("*fields", annotation=generate_name("GraphQLField")), + generate_arg("operation_name", annotation=generate_name("str")), + ], + ), + body=[body_return], + return_type=generate_subscript( + generate_name(DICT), + generate_tuple([generate_name("str"), generate_name("Any")]), + ), + ) + self._class_def.body.append(async_def_query) + def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]: mapped_variable_names = [ self._operation_str_variable, diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index fe614d27..f9927e6b 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -16,17 +16,34 @@ LIST = "List" UNION = "Union" ANY = "Any" +TYPE = "Type" +TYPE_CHECKING = "TYPE_CHECKING" DICT = "Dict" CALLABLE = "Callable" ANNOTATED = "Annotated" LITERAL = "Literal" ASYNC_ITERATOR = "AsyncIterator" +DOCUMENT_NODE = "DocumentNode" +OPERATION_DEFINITION_NODE = "OperationDefinitionNode" +NAME_NODE = "NameNode" +SELECTION_SET_NODE = "SelectionSetNode" +PRINT_AST = "print_ast" +OPERATION_TYPE = "OperationType" + +HTTPX = "httpx" +HTTPX_RESPONSE = "httpx.Response" TIMESTAMP_COMMENT = "# Generated by ariadne-codegen on {}" STABLE_COMMENT = "# Generated by ariadne-codegen" SOURCE_COMMENT = "# Source: {}" COMMENT_DATETIME_FORMAT = "%Y-%m-%d %H:%M" +BASE_OPERATION_FILE_PATH = Path(__file__).parent / "dependencies" / "base_operation.py" +BASE_GRAPHQL_OPERATION_CLASS_NAME = "BaseGraphQLOperation" +BASE_GRAPHQL_FIELD_CLASS_NAME = "GraphQLField" +CUSTOM_FIELDS_FILE_PATH = Path(__file__).parent / "custom_fields.py" +CUSTOM_FIELDS_TYPING_FILE_PATH = Path(__file__).parent / "custom_typing_fields.py" + BASE_MODEL_FILE_PATH = Path(__file__).parent / "dependencies" / "base_model.py" BASE_MODEL_CLASS_NAME = "BaseModel" BASE_MODEL_IMPORT = ast.ImportFrom( @@ -49,6 +66,7 @@ TYPENAME_ALIAS = "typename__" TYPING_MODULE = "typing" +GRAPHQL_MODULE = "graphql" PYDANTIC_MODULE = "pydantic" FIELD_CLASS = "Field" ALIAS_KEYWORD = "alias" @@ -100,3 +118,5 @@ SCALARS_PARSE_DICT_NAME = "SCALARS_PARSE_FUNCTIONS" SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS" + +OPERATION_TYPES = ("Query", "Mutation", "Subscription") diff --git a/ariadne_codegen/client_generators/custom_fields.py b/ariadne_codegen/client_generators/custom_fields.py new file mode 100644 index 00000000..b223b40c --- /dev/null +++ b/ariadne_codegen/client_generators/custom_fields.py @@ -0,0 +1,328 @@ +import ast +from typing import Dict, List, Optional, Set, Union, cast + +from graphql import ( + GraphQLEnumType, + GraphQLInputObjectType, + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, + GraphQLUnionType, +) + +from ariadne_codegen.exceptions import ParsingError + +from ..codegen import ( + generate_ann_assign, + generate_annotation_name, + generate_arg, + generate_arguments, + generate_attribute, + generate_call, + generate_class_def, + generate_constant, + generate_expr, + generate_import_from, + generate_keyword, + generate_method_definition, + generate_module, + generate_name, + generate_return, + generate_subscript, + generate_union_annotation, +) +from ..utils import process_name +from .constants import ( + ANY, + BASE_MODEL_FILE_PATH, + BASE_OPERATION_FILE_PATH, + INPUT_SCALARS_MAP, + OPTIONAL, + TYPING_MODULE, + UNION, + UPLOAD_CLASS_NAME, +) +from .utils import TypeCollector, get_final_type + + +class CustomFieldsGenerator: + def __init__( + self, + schema: GraphQLSchema, + convert_to_snake_case: bool = True, + custom_scalars=None, + ) -> None: + self.schema = schema + self.convert_to_snake_case = convert_to_snake_case + self.custom_scalars = custom_scalars if custom_scalars else {} + self._visited_types: Set[str] = set() + self._field_classes: Set[str] = set() + self._generated_modules: Dict[str, ast.Module] = {} + self._imports: List[ast.ImportFrom] = [ + ast.ImportFrom( + module=BASE_OPERATION_FILE_PATH.stem, + names=[ast.alias("GraphQLField")], + level=1, + ) + ] + self._add_import(generate_import_from([OPTIONAL, UNION], TYPING_MODULE)) + + self._class_defs: List[ast.ClassDef] = self._parse_object_type_definitions( + TypeCollector(self.schema).collect() + ) + + def _add_import(self, import_: Optional[ast.ImportFrom] = None): + if not import_: + return + + if import_.names: + self._imports.append(import_) + + def generate(self) -> ast.Module: + module = generate_module( + body=( + cast(List[ast.stmt], self._imports) + + cast( + List[ast.stmt], + self._class_defs, + ) + ), + ) + + return module + + def _parse_object_type_definitions(self, class_definitions): + class_defs = [] + interface_defs = [] + for type_name in class_definitions: + graphql_type = self.schema.get_type(type_name) + if isinstance(graphql_type, GraphQLObjectType): + class_def = self._generate_class_def_body( + definition=graphql_type, + class_name=f"{graphql_type.name}Fields", + ) + class_defs.append(class_def) + if isinstance(graphql_type, GraphQLInterfaceType): + class_def = self._generate_class_def_body( + definition=graphql_type, + class_name=f"{graphql_type.name}Interface", + ) + class_def.body.append( + self._generate_on_method(f"{graphql_type.name}Interface") + ) + class_defs.append(class_def) + return [*interface_defs, *class_defs] + + def _generate_class_def_body( + self, + definition: Union[GraphQLObjectType, GraphQLInterfaceType], + class_name: str, + ) -> ast.ClassDef: + base_names = ["GraphQLField"] + additional_fields_typing = set() + definition_fields: Dict[str, ast.ClassDef] = dict(definition.fields.items()) + for interface in definition.interfaces: + definition_fields.update(dict(interface.fields.items())) + class_def = generate_class_def(name=class_name, base_names=base_names) + + for lineno, (org_name, field) in enumerate(definition_fields.items(), start=1): + name = process_name( + org_name, + convert_to_snake_case=self.convert_to_snake_case, + ) + final_type = get_final_type(field) + if isinstance(final_type, GraphQLObjectType): + field_name = f"{final_type.name}Fields" + class_def.body.append( + self.generate_product_type_method( + name, field_name, getattr(field, "args") + ) + ) + additional_fields_typing.add(field_name) + elif isinstance(final_type, GraphQLInterfaceType): + field_name = f"{final_type.name}Interface" + class_def.body.append( + self.generate_product_type_method( + name, field_name, getattr(field, "args") + ) + ) + additional_fields_typing.add(field_name) + else: + field_name = f"{definition.name}GraphQLField" + + if isinstance(final_type, GraphQLUnionType): + field_name = f"{final_type.name}Union" + additional_fields_typing.add(field_name) + if getattr(field, "args"): + class_def.body.append( + self.generate_product_type_method( + name, field_name, getattr(field, "args") + ) + ) + else: + self._add_import(generate_import_from([field_name], level=1)) + field_class_name = generate_name(field_name) + + field_implementation = generate_ann_assign( + target=name, + annotation=field_class_name, + value=generate_call( + func=field_class_name, + args=[generate_constant(org_name)], + ), + lineno=lineno, + ) + + class_def.body.append(field_implementation) + + class_def.body.append( + self._generate_fields_method( + class_name, definition.name, sorted(additional_fields_typing) + ) + ) + + return class_def + + def _generate_fields_method( + self, class_name: str, definition_name: str, additional_fields_typing: List + ) -> ast.FunctionDef: + field_class_name = generate_name(f"{definition_name}GraphQLField") + self._add_import( + generate_import_from([f"{definition_name}GraphQLField"], level=1) + ) + fields_annotation: Union[ast.Name, ast.Subscript] = field_class_name + if additional_fields_typing: + additional_fields_typing_ann = [ + generate_name(f'"{field_typing}"') + for field_typing in additional_fields_typing + ] + fields_annotation = generate_union_annotation( + [field_class_name, *additional_fields_typing_ann], nullable=False + ) + + return generate_method_definition( + "fields", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="*subfields", annotation=fields_annotation), + ] + ), + body=[ + generate_expr( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), + attr="_subfields.extend", + ), + args=[generate_name("subfields")], + ) + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) + + def generate_product_type_method( + self, name, class_name, arguments=None + ) -> ast.FunctionDef: + arguments = arguments or {} + return_keywords = [] + field_class_name = generate_name(class_name) + field_kwonlyargs = [] + field_kw_defaults: List[Union[ast.expr, None]] = [] + for arg_name, argument in arguments.items(): + argument_final_type = get_final_type(argument.type) + field_kwonlyargs.append( + generate_arg( + name=arg_name, + annotation=self._parse_graphql_type_name(argument_final_type), + ) + ) + field_kw_defaults.append(generate_constant(value=None)) + return_keywords.append( + generate_keyword(arg=arg_name, value=generate_name(arg_name)) + ) + return generate_method_definition( + name, + arguments=generate_arguments( + args=[generate_arg(name="cls")], + kwonlyargs=field_kwonlyargs, + kw_defaults=field_kw_defaults, + ), + body=[ + generate_return( + value=generate_call( + func=field_class_name, + args=[generate_constant(name)], + keywords=return_keywords, + ) + ), + ], + return_type=generate_name(f'"{class_name}"'), + decorator_list=[generate_name("classmethod")], + ) + + def _generate_on_method(self, class_name: str) -> ast.FunctionDef: + return generate_method_definition( + "on", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="type_name", annotation=generate_name("str")), + generate_arg( + name="*subfields", annotation=generate_name("GraphQLField") + ), + ] + ), + body=[ + cast( + ast.stmt, + ast.Assign( + targets=[ + generate_subscript( + value=generate_attribute( + value=generate_name("self"), + attr="_inline_fragments", + ), + slice_=generate_name("type_name"), + ) + ], + value=generate_name("subfields"), + lineno=1, + ), + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) + + def _parse_graphql_type_name( + self, type_, nullable: bool = True + ) -> Union[ast.Name, ast.Subscript]: + name = type_.name + + if isinstance(type_, GraphQLInputObjectType): + self._add_import( + generate_import_from(names=[name], from_="input_types", level=1) + ) + elif isinstance(type_, GraphQLEnumType): + self._add_import(generate_import_from(names=[name], level=1)) + elif isinstance(type_, GraphQLScalarType): + if name not in self.custom_scalars: + name = INPUT_SCALARS_MAP.get(name, ANY) + if name == UPLOAD_CLASS_NAME: + self._add_import( + generate_import_from( + names=[UPLOAD_CLASS_NAME], + from_=BASE_MODEL_FILE_PATH.stem, + level=1, + ) + ) + else: + name = self.custom_scalars[name].type_name + else: + raise ParsingError(f"Incorrect argument type {name}") + + return generate_annotation_name(name, nullable) diff --git a/ariadne_codegen/client_generators/custom_fields_typing.py b/ariadne_codegen/client_generators/custom_fields_typing.py new file mode 100644 index 00000000..d38a2fc8 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_fields_typing.py @@ -0,0 +1,112 @@ +import ast +from typing import List, cast + +from graphql import ( + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLSchema, + GraphQLUnionType, +) + +from ariadne_codegen.client_generators.utils import get_final_type + +from ..codegen import ( + generate_arg, + generate_arguments, + generate_attribute, + generate_class_def, + generate_method_definition, + generate_module, + generate_name, + generate_return, + generate_subscript, +) +from .constants import BASE_OPERATION_FILE_PATH, OPERATION_TYPES + + +class CustomFieldsTypingGenerator: + def __init__( + self, + schema: GraphQLSchema, + ) -> None: + self.schema = schema + self.graphql_field_import = ast.ImportFrom( + module=BASE_OPERATION_FILE_PATH.stem, + names=[ast.alias("GraphQLField")], + level=1, + ) + self._public_names: List[str] = [] + self._class_defs: List[ast.ClassDef] = [ + self._generate_field_class(d) for d in self._filter_types() + ] + + def generate(self) -> ast.Module: + return generate_module( + body=( + cast(List[ast.stmt], [self.graphql_field_import]) + + cast(List[ast.stmt], [self._class_defs]) + ) + ) + + def _filter_types(self): + return [ + get_final_type(definition) + for name, definition in self.schema.type_map.items() + if isinstance( + definition, (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType) + ) + and not name.startswith("__") + and name not in OPERATION_TYPES + ] + + def _generate_field_class(self, class_def: ast.ClassDef) -> ast.ClassDef: + class_name = f"{class_def.name}GraphQLField" + class_body: List[ast.stmt] = [] + if isinstance(class_def, GraphQLUnionType): + class_name = f"{class_def.name}Union" + class_body.append(self._generate_on_method(class_name)) + if class_name not in self._public_names: + self._public_names.append(class_name) + field_class_def = generate_class_def( + name=class_name, + base_names=["GraphQLField"], + body=class_body if class_body else cast(List[ast.stmt], [ast.Pass()]), + ) + return field_class_def + + def _generate_on_method(self, class_name: str) -> ast.FunctionDef: + return generate_method_definition( + "on", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="type_name", annotation=generate_name("str")), + generate_arg( + name="*subfields", annotation=generate_name("GraphQLField") + ), + ] + ), + body=[ + cast( + ast.stmt, + ast.Assign( + targets=[ + generate_subscript( + value=generate_attribute( + value=generate_name("self"), + attr="_inline_fragments", + ), + slice_=generate_name("type_name"), + ) + ], + value=generate_name("subfields"), + lineno=1, + ), + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) + + def get_generated_public_names(self) -> List[str]: + return self._public_names diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py new file mode 100644 index 00000000..c0d9b2e7 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -0,0 +1,219 @@ +import ast +from typing import Dict, List, Optional, Tuple, Union, cast + +from graphql import ( + GraphQLEnumType, + GraphQLFieldMap, + GraphQLInputObjectType, + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLScalarType, + GraphQLUnionType, +) + +from ariadne_codegen.exceptions import ParsingError +from ariadne_codegen.utils import str_to_snake_case + +from ..codegen import ( + generate_annotation_name, + generate_arg, + generate_arguments, + generate_call, + generate_class_def, + generate_constant, + generate_import_from, + generate_keyword, + generate_method_definition, + generate_module, + generate_name, + generate_return, +) +from ..plugins.manager import PluginManager +from .constants import ( + ANY, + BASE_MODEL_FILE_PATH, + CUSTOM_FIELDS_FILE_PATH, + CUSTOM_FIELDS_TYPING_FILE_PATH, + INPUT_SCALARS_MAP, + OPTIONAL, + TYPING_MODULE, + UPLOAD_CLASS_NAME, +) +from .scalars import ScalarData +from .utils import get_final_type + + +class CustomOperationGenerator: + def __init__( + self, + graphql_fields: GraphQLFieldMap, + name: str, + base_name: str, + enums_module_name: str = "enums", + custom_scalars: Optional[Dict[str, ScalarData]] = None, + plugin_manager: Optional[PluginManager] = None, + ) -> None: + self.graphql_fields = graphql_fields + self.name = name + self.base_name = base_name + self.enums_module_name = enums_module_name + self.plugin_manager = plugin_manager + self.custom_scalars = custom_scalars if custom_scalars else {} + + self._imports: List[ast.ImportFrom] = [] + self._type_imports: List[ast.ImportFrom] = [] + self._add_import(generate_import_from([OPTIONAL, ANY], TYPING_MODULE)) + + self._class_def = generate_class_def(name=name, base_names=[]) + + self._used_inputs: List[str] = [] + + def generate(self) -> ast.Module: + """Generate module with class definition of graphql client.""" + + for name, field in self.graphql_fields.items(): + final_type = get_final_type(field) + # if isinstance(final_type, GraphQLObjectType): + method_def = self._generate_method( + operation_name=name, + operation_args=field.args, + final_type=final_type, + ) + method_def.lineno = len(self._class_def.body) + 1 + self._class_def.body.append(method_def) + + if not self._class_def.body: + self._class_def.body.append(ast.Pass()) + + self._class_def.lineno = len(self._imports) + 3 + + module = generate_module( + body=cast(List[ast.stmt], self._imports) + + cast(List[ast.stmt], self._type_imports) + + [self._class_def], + ) + return module + + def _add_import(self, import_: Optional[ast.ImportFrom] = None): + if import_: + if self.plugin_manager: + import_ = self.plugin_manager.generate_client_import(import_) + if import_.names and import_.module: + self._imports.append(import_) + + def _generate_method( + self, + operation_name: str, + operation_args, + final_type, + ) -> ast.FunctionDef: + arguments = self._generate_method_arguments(operation_args) + from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem + if isinstance(final_type, GraphQLObjectType): + return_type_name = f"{final_type.name}Fields" + from_ = CUSTOM_FIELDS_FILE_PATH.stem + elif isinstance(final_type, GraphQLInterfaceType): + return_type_name = f"{final_type.name}Interface" + from_ = CUSTOM_FIELDS_FILE_PATH.stem + elif isinstance(final_type, GraphQLUnionType): + return_type_name = f"{final_type.name}Union" + else: + return_type_name = "GraphQLField" + self._type_imports.append( + generate_import_from( + from_=from_, + names=[return_type_name], + level=1, + ) + ) + + return generate_method_definition( + name=str_to_snake_case(operation_name), + arguments=arguments, + return_type=generate_name(return_type_name), + body=[ + self._generate_return_stmt( + return_type_name, + operation_name, + operation_args, + ) + ], + decorator_list=[generate_name("classmethod")], + ) + + def _generate_method_arguments(self, operation_args): + cls_arg = generate_arg(name="cls") + kw_only_args, kw_defaults = self._generate_kw_args_and_defaults(operation_args) + return generate_arguments( + args=[cls_arg], + kwonlyargs=kw_only_args, + kw_defaults=kw_defaults, + ) + + def _generate_kw_args_and_defaults(self, operation_args): + kw_only_args = [] + kw_defaults = [] + for arg_name, arg_type in operation_args.items(): + arg_final_type = get_final_type(arg_type) + annotation, _ = self._parse_graphql_type_name(arg_final_type) + kw_only_args.append(generate_arg(name=arg_name, annotation=annotation)) + kw_defaults.append(generate_constant(value=None)) + return kw_only_args, kw_defaults + + def _generate_return_stmt(self, return_type_name, operation_name, operation_args): + keywords = [ + generate_keyword(arg=arg_name, value=generate_name(arg_name)) + for arg_name in operation_args + ] + return generate_return( + value=generate_call( + func=generate_name(return_type_name), + args=[], + keywords=[ + generate_keyword( + arg="field_name", value=generate_constant(value=operation_name) + ), + *keywords, + ], + ) + ) + + def _parse_graphql_type_name( + self, type_, nullable: bool = True + ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: + name = type_.name + + used_custom_scalar = None + if isinstance(type_, GraphQLInputObjectType): + self._used_inputs.append(name) + self._add_import( + generate_import_from( + names=[name], + from_="input_types", + level=1, + ) + ) + elif isinstance(type_, GraphQLEnumType): + self._add_import(generate_import_from(names=[name], level=1)) + elif isinstance(type_, GraphQLScalarType): + if name not in self.custom_scalars: + name = INPUT_SCALARS_MAP.get(name, ANY) + if name == UPLOAD_CLASS_NAME: + self._add_import( + generate_import_from( + names=[UPLOAD_CLASS_NAME], + from_=BASE_MODEL_FILE_PATH.stem, + level=1, + ) + ) + else: + used_custom_scalar = name + name = self.custom_scalars[name].type_name + else: + raise ParsingError(f"Incorrect argument type {name}") + + return generate_annotation_name(name, nullable), used_custom_scalar + + @staticmethod + def _capitalize_first_letter(s: str) -> str: + return s[0].upper() + s[1:] diff --git a/ariadne_codegen/client_generators/dependencies/base_operation.py b/ariadne_codegen/client_generators/dependencies/base_operation.py new file mode 100644 index 00000000..a488cc73 --- /dev/null +++ b/ariadne_codegen/client_generators/dependencies/base_operation.py @@ -0,0 +1,92 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +from graphql import ( + ArgumentNode, + BooleanValueNode, + FieldNode, + FloatValueNode, + InlineFragmentNode, + IntValueNode, + NamedTypeNode, + NameNode, + ObjectFieldNode, + ObjectValueNode, + SelectionSetNode, + StringValueNode, +) + +from .base_model import BaseModel + + +class GraphQLArgument: + def __init__(self, argument_name: str, value: Any): + self._name = argument_name + self._value = self._convert_value(value) + + def _convert_value( + self, value: Any + ) -> Union[ + StringValueNode, IntValueNode, FloatValueNode, BooleanValueNode, ObjectValueNode + ]: + if isinstance(value, str): + return StringValueNode(value=value) + if isinstance(value, int): + return IntValueNode(value=str(value)) + if isinstance(value, float): + return FloatValueNode(value=str(value)) + if isinstance(value, bool): + return BooleanValueNode(value=value) + if isinstance(value, BaseModel): + fields = [ + ObjectFieldNode(name=NameNode(value=k), value=self._convert_value(v)) + for k, v in value.model_dump().items() + ] + return ObjectValueNode(fields=fields) + raise TypeError(f"Unsupported argument type: {type(value)}") + + def to_ast(self) -> ArgumentNode: + return ArgumentNode(name=NameNode(value=self._name), value=self._value) + + +class GraphQLField: + def __init__(self, field_name: str, **kwargs: Any) -> None: + self._field_name: str = field_name + self._arguments: List[GraphQLArgument] = [ + GraphQLArgument(k, v) for k, v in kwargs.items() if v + ] + self._subfields: List["GraphQLField"] = [] + self._alias: Optional[str] = None + self._inline_fragments: Dict[str, Tuple["GraphQLField", ...]] = {} + + def alias(self, alias: str) -> "GraphQLField": + self._alias = alias + return self + + def _build_field_name(self) -> str: + if self._alias: + return f"{self._alias}: {self._field_name}" + return self._field_name + + def to_ast(self) -> FieldNode: + selections: List[Union[FieldNode, InlineFragmentNode]] = [ + sub_field.to_ast() for sub_field in self._subfields + ] + if self._inline_fragments: + selections.extend( + [ + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[sub_field.to_ast() for sub_field in subfields] + ), + ) + for name, subfields in self._inline_fragments.items() + ] + ) + return FieldNode( + name=NameNode(value=self._build_field_name()), + arguments=[arg.to_ast() for arg in self._arguments], + selection_set=( + SelectionSetNode(selections=selections) if selections else None + ), + ) diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index e17ec5e6..a49901c3 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -2,7 +2,12 @@ from pathlib import Path from typing import Dict, List, Optional, Set -from graphql import FragmentDefinitionNode, GraphQLSchema, OperationDefinitionNode +from graphql import ( + FragmentDefinitionNode, + GraphQLSchema, + OperationDefinitionNode, + OperationType, +) from ..codegen import generate_import_from from ..exceptions import ParsingError @@ -13,9 +18,11 @@ from .client import ClientGenerator from .comments import get_comment from .constants import ( + BASE_GRAPHQL_OPERATION_CLASS_NAME, BASE_MODEL_CLASS_NAME, BASE_MODEL_FILE_PATH, BASE_MODEL_IMPORT, + BASE_OPERATION_FILE_PATH, DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH, DEFAULT_ASYNC_BASE_CLIENT_PATH, DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_PATH, @@ -26,6 +33,9 @@ UPLOAD_CLASS_NAME, UPLOAD_IMPORT, ) +from .custom_fields import CustomFieldsGenerator +from .custom_fields_typing import CustomFieldsTypingGenerator +from .custom_operation import CustomOperationGenerator from .enums import EnumsGenerator from .fragments import FragmentsGenerator from .init_file import InitFileGenerator @@ -45,6 +55,10 @@ def __init__( enums_generator: EnumsGenerator, input_types_generator: InputTypesGenerator, fragments_generator: FragmentsGenerator, + custom_fields_generator: Optional[CustomFieldsGenerator] = None, + custom_fields_typing_generator: Optional[CustomFieldsTypingGenerator] = None, + custom_query_generator: Optional[CustomOperationGenerator] = None, + custom_mutation_generator: Optional[CustomOperationGenerator] = None, fragments_definitions: Optional[Dict[str, FragmentDefinitionNode]] = None, client_name: str = "Client", async_client: bool = True, @@ -54,6 +68,7 @@ def __init__( enums_module_name: str = "enums", input_types_module_name: str = "input_types", fragments_module_name: str = "fragments", + custom_help_field_module_name: str = "custom_typing_fields", comments_strategy: CommentsStrategy = CommentsStrategy.STABLE, queries_source: str = "", schema_source: str = "", @@ -61,12 +76,14 @@ def __init__( include_all_inputs: bool = True, include_all_enums: bool = True, base_model_file_path: str = BASE_MODEL_FILE_PATH.as_posix(), + base_schema_root_file_path: str = BASE_OPERATION_FILE_PATH.as_posix(), base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT, upload_import: ast.ImportFrom = UPLOAD_IMPORT, unset_import: ast.ImportFrom = UNSET_IMPORT, files_to_include: Optional[List[str]] = None, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, + enable_custom_operations: bool = False, ) -> None: self.package_path = Path(target_path) / package_name @@ -80,6 +97,11 @@ def __init__( self.enums_generator = enums_generator self.input_types_generator = input_types_generator self.fragments_generator = fragments_generator + self.custom_fields_generator = custom_fields_generator + self.custom_query_generator = custom_query_generator + self.custom_mutation_generator = custom_mutation_generator + self.custom_fields_typing_generator = custom_fields_typing_generator + self.custom_help_field_module_name = custom_help_field_module_name self.client_name = client_name self.async_client = async_client @@ -104,6 +126,8 @@ def __init__( self.upload_import = upload_import self.unset_import = unset_import + self.base_schema_root_file_path = Path(base_schema_root_file_path) + self.files_to_include = ( [Path(f) for f in files_to_include] if files_to_include else [] ) @@ -115,6 +139,10 @@ def __init__( self._unpacked_fragments: Set[str] = set() self._used_enums: List[str] = [] + self.enable_custom_operations = enable_custom_operations + if self.enable_custom_operations: + self.files_to_include.append(self.base_schema_root_file_path) + def generate(self) -> List[str]: """Generate package with graphql client.""" self._include_exceptions() @@ -125,6 +153,21 @@ def generate(self) -> List[str]: self._generate_result_types() self._generate_fragments() self._copy_files() + if self.enable_custom_operations: + self._generate_custom_fields_typing() + self._generate_custom_fields() + self.client_generator.add_execute_custom_operation_method() + if self.custom_query_generator: + self._generate_custom_queries() + self.client_generator.create_custom_operation_method( + "query", OperationType.QUERY.value.upper() + ) + if self.custom_mutation_generator: + self._generate_custom_mutations() + self.client_generator.create_custom_operation_method( + "mutation", OperationType.MUTATION.value.upper() + ) + self._generate_client() self._generate_enums() self._generate_init() @@ -335,6 +378,39 @@ def _generate_init(self): init_file_path.write_text(code) self._generated_files.append(init_file_path.name) + def _generate_custom_queries(self): + file_path = self.package_path / "custom_queries.py" + module = self.custom_query_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + + def _generate_custom_mutations(self): + file_path = self.package_path / "custom_mutations.py" + module = self.custom_mutation_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + + def _generate_custom_fields_typing(self): + file_path = self.package_path / "custom_typing_fields.py" + module = self.custom_fields_typing_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + self.init_generator.add_import( + self.custom_fields_typing_generator.get_generated_public_names(), + self.custom_help_field_module_name, + 1, + ) + + def _generate_custom_fields(self): + file_path = self.package_path / "custom_fields.py" + module = self.custom_fields_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + def get_package_generator( schema: GraphQLSchema, @@ -384,6 +460,28 @@ def get_package_generator( custom_scalars=settings.scalars, plugin_manager=plugin_manager, ) + custom_fields_generator = CustomFieldsGenerator(schema=schema) + custom_fields_typing_generator = CustomFieldsTypingGenerator(schema=schema) + custom_query_generator = None + if schema.query_type: + custom_query_generator = CustomOperationGenerator( + graphql_fields=schema.query_type.fields, + name="Query", + base_name=BASE_GRAPHQL_OPERATION_CLASS_NAME, + enums_module_name=settings.enums_module_name, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ) + custom_mutation_generator = None + if schema.mutation_type: + custom_mutation_generator = CustomOperationGenerator( + graphql_fields=schema.mutation_type.fields, + name="Mutation", + base_name=BASE_GRAPHQL_OPERATION_CLASS_NAME, + enums_module_name=settings.enums_module_name, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ) return PackageGenerator( package_name=settings.target_package_name, @@ -403,6 +501,10 @@ def get_package_generator( enums_module_name=settings.enums_module_name, input_types_module_name=settings.input_types_module_name, fragments_module_name=settings.fragments_module_name, + custom_fields_generator=custom_fields_generator, + custom_fields_typing_generator=custom_fields_typing_generator, + custom_query_generator=custom_query_generator, + custom_mutation_generator=custom_mutation_generator, comments_strategy=settings.include_comments, queries_source=settings.queries_path, schema_source=settings.schema_source, @@ -416,4 +518,5 @@ def get_package_generator( files_to_include=settings.files_to_include, custom_scalars=settings.scalars, plugin_manager=plugin_manager, + enable_custom_operations=settings.enable_custom_operations, ) diff --git a/ariadne_codegen/client_generators/utils.py b/ariadne_codegen/client_generators/utils.py new file mode 100644 index 00000000..3540677f --- /dev/null +++ b/ariadne_codegen/client_generators/utils.py @@ -0,0 +1,56 @@ +from typing import Dict, List, Set + +from graphql import ( + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLSchema, + GraphQLUnionType, +) + + +class TypeCollector: + def __init__(self, schema: GraphQLSchema): + self.schema = schema + self.collected_types: Set[str] = set() + self.visited_types: Set[str] = set() + + def collect(self) -> List[str]: + if self.schema.query_type: + self._collect_types(self.schema.query_type.fields) + if self.schema.mutation_type: + self._collect_types(self.schema.mutation_type.fields) + return sorted(self.collected_types) + + def _collect_types(self, fields: Dict[str, GraphQLObjectType]) -> None: + for field in fields.values(): + graphql_type = get_final_type(field) + self._collect_dependent_types(graphql_type) + + def _collect_dependent_types(self, graphql_type: GraphQLObjectType) -> None: + stack = [graphql_type] + + while stack: + current_type = stack.pop() + if current_type.name in self.visited_types: + continue + + self.visited_types.add(current_type.name) + self.collected_types.add(current_type.name) + + if isinstance(current_type, (GraphQLObjectType, GraphQLInterfaceType)): + for subfield in current_type.fields.values(): + subfield_type = get_final_type(subfield) + if isinstance(subfield_type, GraphQLObjectType): + stack.append(subfield_type) + elif isinstance(subfield_type, GraphQLUnionType): + stack.extend(subfield_type.types) + for interface in current_type.interfaces: + stack.append(interface) + + +def get_final_type(type_): + while hasattr(type_, "of_type"): + type_ = type_.of_type + if hasattr(type_, "type"): + return get_final_type(type_.type) + return type_ diff --git a/ariadne_codegen/codegen.py b/ariadne_codegen/codegen.py index 40a8b894..c4893511 100644 --- a/ariadne_codegen/codegen.py +++ b/ariadne_codegen/codegen.py @@ -24,8 +24,13 @@ from .exceptions import ParsingError +def generate_import(names: List[str], level: int = 0) -> ast.Import: + """Generate import statement.""" + return ast.Import(names=[ast.alias(n) for n in names], level=level) + + def generate_import_from( - names: List[str], from_: str, level: int = 0 + names: List[str], from_: Optional[str] = None, level: int = 0 ) -> ast.ImportFrom: """Generate import from statement.""" return ast.ImportFrom( @@ -34,7 +39,7 @@ def generate_import_from( def generate_nullable_annotation( - slice_: Union[ast.Name, ast.Subscript] + slice_: Union[ast.Name, ast.Subscript], ) -> ast.Subscript: """Generate optional annotation.""" return ast.Subscript(value=ast.Name(id=OPTIONAL), slice=slice_) @@ -65,15 +70,19 @@ def generate_arg( def generate_arguments( args: Optional[List[ast.arg]] = None, - defaults: Optional[List[ast.expr]] = None, + vararg: Optional[ast.arg] = None, + kwonlyargs: Optional[list[ast.arg]] = None, + kw_defaults: Optional[list[Union[ast.expr, None]]] = None, kwarg: Optional[ast.arg] = None, + defaults: Optional[List[ast.expr]] = None, ) -> ast.arguments: """Generate arguments.""" return ast.arguments( posonlyargs=[], args=args if args else [], - kwonlyargs=[], - kw_defaults=[], + vararg=vararg, + kwonlyargs=kwonlyargs if kwonlyargs else [], + kw_defaults=kw_defaults if kw_defaults else [], kwarg=kwarg, defaults=defaults or [], ) @@ -85,13 +94,14 @@ def generate_async_method_definition( return_type: Union[ast.Name, ast.Subscript], body: Optional[List[ast.stmt]] = None, lineno: int = 1, + decorator_list: Optional[List[ast.Name]] = None, ) -> ast.AsyncFunctionDef: """Generate async function.""" return ast.AsyncFunctionDef( name=name, args=arguments, body=body if body else [ast.Pass()], - decorator_list=[], + decorator_list=decorator_list if decorator_list else [], returns=return_type, lineno=lineno, ) @@ -118,11 +128,26 @@ def generate_name(name: str) -> ast.Name: return ast.Name(id=name) +def generate_joined_str(values: list[ast.expr]) -> ast.JoinedStr: + """Generate joined str object.""" + return ast.JoinedStr(values) + + def generate_constant(value: Any) -> ast.Constant: """Generate constant object.""" return ast.Constant(value=value) +def generate_formatted_value( + value: ast.expr, conversion: int = -1, format_spec: Optional[ast.expr] = None +) -> ast.FormattedValue: + return ast.FormattedValue( + value=value, + conversion=conversion, + format_spec=format_spec, + ) + + def generate_assign( targets: List[str], value: Union[ast.expr, List[ast.expr]], lineno: int = 1 ) -> ast.Assign: @@ -267,6 +292,25 @@ def generate_list(elements: List[Optional[ast.expr]]) -> ast.List: return ast.List(elts=elements) +def generate_list_comp( + elt: ast.expr, generators: list[ast.comprehension] +) -> ast.ListComp: + """Generate list comprehension""" + return ast.ListComp(elt=elt, generators=generators) + + +def generate_comp( + target: str, iter_: str, ifs: Optional[List[ast.expr]] = None, is_async: int = 0 +) -> ast.comprehension: + "Generate comprehension" + return ast.comprehension( + target=generate_name(target), + iter=generate_name(iter_), + ifs=ifs if ifs else [], + is_async=is_async, + ) + + def generate_lambda(body: ast.expr, args: Optional[ast.arguments] = None) -> ast.Lambda: """Generate lambda definition.""" return ast.Lambda(args=args or generate_arguments(), body=body) @@ -299,12 +343,13 @@ def generate_method_definition( return_type: Union[ast.Name, ast.Subscript], body: Optional[List[ast.stmt]] = None, lineno: int = 1, + decorator_list: Optional[List[ast.Name]] = None, ) -> ast.FunctionDef: return ast.FunctionDef( name=name, args=arguments, body=body if body else [ast.Pass()], - decorator_list=[], + decorator_list=decorator_list if decorator_list else [], returns=return_type, lineno=lineno, ) diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 57eb60c5..2dcd603c 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -61,9 +61,12 @@ def client(config_dict): schema = plugin_manager.process_schema(schema) assert_valid_schema(schema) - definitions = get_graphql_queries(settings.queries_path, schema) - queries = filter_operations_definitions(definitions) - fragments = filter_fragments_definitions(definitions) + fragments = [] + queries = [] + if settings.queries_path: + definitions = get_graphql_queries(settings.queries_path, schema) + queries = filter_operations_definitions(definitions) + fragments = filter_fragments_definitions(definitions) sys.stdout.write(settings.used_settings_message) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index fbd8f427..808397ba 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -37,6 +37,7 @@ class BaseSettings: remote_schema_url: str = "" remote_schema_headers: dict = field(default_factory=dict) remote_schema_verify_ssl: bool = True + enable_custom_operations: bool = False plugins: List[str] = field(default_factory=list) def __post_init__(self): @@ -73,7 +74,7 @@ class ClientSettings(BaseSettings): scalars: Dict[str, ScalarData] = field(default_factory=dict) def __post_init__(self): - if not self.queries_path: + if not self.queries_path and not self.enable_custom_operations: raise TypeError("__init__ missing 1 required argument: 'queries_path'") super().__post_init__() diff --git a/tests/main/clients/custom_query_builder/expected_client/__init__.py b/tests/main/clients/custom_query_builder/expected_client/__init__.py new file mode 100644 index 00000000..9fb8977c --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/__init__.py @@ -0,0 +1,56 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .custom_typing_fields import ( + AppGraphQLField, + CollectionTranslatableContentGraphQLField, + MetadataErrorGraphQLField, + MetadataItemGraphQLField, + ObjectWithMetadataGraphQLField, + PageInfoGraphQLField, + ProductCountableConnectionGraphQLField, + ProductCountableEdgeGraphQLField, + ProductGraphQLField, + ProductTranslatableContentGraphQLField, + ProductTypeCountableConnectionGraphQLField, + TranslatableItemConnectionGraphQLField, + TranslatableItemEdgeGraphQLField, + TranslatableItemUnion, + UpdateMetadataGraphQLField, +) +from .enums import MetadataErrorCode +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) + +__all__ = [ + "AppGraphQLField", + "AsyncBaseClient", + "BaseModel", + "Client", + "CollectionTranslatableContentGraphQLField", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "MetadataErrorCode", + "MetadataErrorGraphQLField", + "MetadataItemGraphQLField", + "ObjectWithMetadataGraphQLField", + "PageInfoGraphQLField", + "ProductCountableConnectionGraphQLField", + "ProductCountableEdgeGraphQLField", + "ProductGraphQLField", + "ProductTranslatableContentGraphQLField", + "ProductTypeCountableConnectionGraphQLField", + "TranslatableItemConnectionGraphQLField", + "TranslatableItemEdgeGraphQLField", + "TranslatableItemUnion", + "UpdateMetadataGraphQLField", + "Upload", +] diff --git a/tests/main/clients/custom_query_builder/expected_client/async_base_client.py b/tests/main/clients/custom_query_builder/expected_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/custom_query_builder/expected_client/base_model.py b/tests/main/clients/custom_query_builder/expected_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/custom_query_builder/expected_client/base_operation.py b/tests/main/clients/custom_query_builder/expected_client/base_operation.py new file mode 100644 index 00000000..a488cc73 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/base_operation.py @@ -0,0 +1,92 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +from graphql import ( + ArgumentNode, + BooleanValueNode, + FieldNode, + FloatValueNode, + InlineFragmentNode, + IntValueNode, + NamedTypeNode, + NameNode, + ObjectFieldNode, + ObjectValueNode, + SelectionSetNode, + StringValueNode, +) + +from .base_model import BaseModel + + +class GraphQLArgument: + def __init__(self, argument_name: str, value: Any): + self._name = argument_name + self._value = self._convert_value(value) + + def _convert_value( + self, value: Any + ) -> Union[ + StringValueNode, IntValueNode, FloatValueNode, BooleanValueNode, ObjectValueNode + ]: + if isinstance(value, str): + return StringValueNode(value=value) + if isinstance(value, int): + return IntValueNode(value=str(value)) + if isinstance(value, float): + return FloatValueNode(value=str(value)) + if isinstance(value, bool): + return BooleanValueNode(value=value) + if isinstance(value, BaseModel): + fields = [ + ObjectFieldNode(name=NameNode(value=k), value=self._convert_value(v)) + for k, v in value.model_dump().items() + ] + return ObjectValueNode(fields=fields) + raise TypeError(f"Unsupported argument type: {type(value)}") + + def to_ast(self) -> ArgumentNode: + return ArgumentNode(name=NameNode(value=self._name), value=self._value) + + +class GraphQLField: + def __init__(self, field_name: str, **kwargs: Any) -> None: + self._field_name: str = field_name + self._arguments: List[GraphQLArgument] = [ + GraphQLArgument(k, v) for k, v in kwargs.items() if v + ] + self._subfields: List["GraphQLField"] = [] + self._alias: Optional[str] = None + self._inline_fragments: Dict[str, Tuple["GraphQLField", ...]] = {} + + def alias(self, alias: str) -> "GraphQLField": + self._alias = alias + return self + + def _build_field_name(self) -> str: + if self._alias: + return f"{self._alias}: {self._field_name}" + return self._field_name + + def to_ast(self) -> FieldNode: + selections: List[Union[FieldNode, InlineFragmentNode]] = [ + sub_field.to_ast() for sub_field in self._subfields + ] + if self._inline_fragments: + selections.extend( + [ + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[sub_field.to_ast() for sub_field in subfields] + ), + ) + for name, subfields in self._inline_fragments.items() + ] + ) + return FieldNode( + name=NameNode(value=self._build_field_name()), + arguments=[arg.to_ast() for arg in self._arguments], + selection_set=( + SelectionSetNode(selections=selections) if selections else None + ), + ) diff --git a/tests/main/clients/custom_query_builder/expected_client/client.py b/tests/main/clients/custom_query_builder/expected_client/client.py new file mode 100644 index 00000000..4e47fc3b --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/client.py @@ -0,0 +1,52 @@ +from typing import Any, Dict + +from graphql import ( + DocumentNode, + NameNode, + OperationDefinitionNode, + OperationType, + SelectionSetNode, + print_ast, +) + +from .async_base_client import AsyncBaseClient +from .base_operation import GraphQLField + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def execute_custom_operation( + self, *fields: GraphQLField, operation_type: OperationType, operation_name: str + ) -> Dict[str, Any]: + operation_ast = DocumentNode( + definitions=[ + OperationDefinitionNode( + operation=operation_type, + name=NameNode(value=operation_name), + selection_set=SelectionSetNode( + selections=[field.to_ast() for field in fields] + ), + ) + ] + ) + response = await self.execute( + print_ast(operation_ast), operation_name=operation_name + ) + return self.get_data(response) + + async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, operation_type=OperationType.QUERY, operation_name=operation_name + ) + + async def mutation( + self, *fields: GraphQLField, operation_name: str + ) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, + operation_type=OperationType.MUTATION, + operation_name=operation_name + ) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py new file mode 100644 index 00000000..7bb33e07 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py @@ -0,0 +1,292 @@ +from typing import Optional, Union + +from . import ( + AppGraphQLField, + CollectionTranslatableContentGraphQLField, + MetadataErrorGraphQLField, + MetadataItemGraphQLField, + ObjectWithMetadataGraphQLField, + PageInfoGraphQLField, + ProductCountableConnectionGraphQLField, + ProductCountableEdgeGraphQLField, + ProductGraphQLField, + ProductTranslatableContentGraphQLField, + ProductTypeCountableConnectionGraphQLField, + TranslatableItemConnectionGraphQLField, + TranslatableItemEdgeGraphQLField, + TranslatableItemUnion, + UpdateMetadataGraphQLField, +) +from .base_operation import GraphQLField + + +class AppFields(GraphQLField): + id: AppGraphQLField = AppGraphQLField("id") + + def fields(self, *subfields: AppGraphQLField) -> "AppFields": + self._subfields.extend(subfields) + return self + + +class CollectionTranslatableContentFields(GraphQLField): + id: CollectionTranslatableContentGraphQLField = ( + CollectionTranslatableContentGraphQLField("id") + ) + collection_id: CollectionTranslatableContentGraphQLField = ( + CollectionTranslatableContentGraphQLField("collectionId") + ) + seo_title: CollectionTranslatableContentGraphQLField = ( + CollectionTranslatableContentGraphQLField("seoTitle") + ) + seo_description: CollectionTranslatableContentGraphQLField = ( + CollectionTranslatableContentGraphQLField("seoDescription") + ) + name: CollectionTranslatableContentGraphQLField = ( + CollectionTranslatableContentGraphQLField("name") + ) + description: CollectionTranslatableContentGraphQLField = ( + CollectionTranslatableContentGraphQLField("description") + ) + + def fields( + self, *subfields: CollectionTranslatableContentGraphQLField + ) -> "CollectionTranslatableContentFields": + self._subfields.extend(subfields) + return self + + +class MetadataErrorFields(GraphQLField): + field: MetadataErrorGraphQLField = MetadataErrorGraphQLField("field") + message: MetadataErrorGraphQLField = MetadataErrorGraphQLField("message") + code: MetadataErrorGraphQLField = MetadataErrorGraphQLField("code") + + def fields(self, *subfields: MetadataErrorGraphQLField) -> "MetadataErrorFields": + self._subfields.extend(subfields) + return self + + +class MetadataItemFields(GraphQLField): + key: MetadataItemGraphQLField = MetadataItemGraphQLField("key") + value: MetadataItemGraphQLField = MetadataItemGraphQLField("value") + + def fields(self, *subfields: MetadataItemGraphQLField) -> "MetadataItemFields": + self._subfields.extend(subfields) + return self + + +class ObjectWithMetadataInterface(GraphQLField): + @classmethod + def private_metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("private_metadata") + + @classmethod + def private_metafield( + cls, *, key: Optional[str] = None + ) -> "ObjectWithMetadataGraphQLField": + return ObjectWithMetadataGraphQLField("private_metafield", key=key) + + @classmethod + def metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("metadata") + + @classmethod + def metafield( + cls, *, key: Optional[str] = None + ) -> "ObjectWithMetadataGraphQLField": + return ObjectWithMetadataGraphQLField("metafield", key=key) + + def fields( + self, *subfields: Union[ObjectWithMetadataGraphQLField, "MetadataItemFields"] + ) -> "ObjectWithMetadataInterface": + self._subfields.extend(subfields) + return self + + def on( + self, type_name: str, *subfields: GraphQLField + ) -> "ObjectWithMetadataInterface": + self._inline_fragments[type_name] = subfields + return self + + +class PageInfoFields(GraphQLField): + has_next_page: PageInfoGraphQLField = PageInfoGraphQLField("hasNextPage") + has_previous_page: PageInfoGraphQLField = PageInfoGraphQLField("hasPreviousPage") + start_cursor: PageInfoGraphQLField = PageInfoGraphQLField("startCursor") + end_cursor: PageInfoGraphQLField = PageInfoGraphQLField("endCursor") + + def fields(self, *subfields: PageInfoGraphQLField) -> "PageInfoFields": + self._subfields.extend(subfields) + return self + + +class ProductFields(GraphQLField): + id: ProductGraphQLField = ProductGraphQLField("id") + slug: ProductGraphQLField = ProductGraphQLField("slug") + name: ProductGraphQLField = ProductGraphQLField("name") + + @classmethod + def private_metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("private_metadata") + + @classmethod + def private_metafield(cls, *, key: Optional[str] = None) -> "ProductGraphQLField": + return ProductGraphQLField("private_metafield", key=key) + + @classmethod + def metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("metadata") + + @classmethod + def metafield(cls, *, key: Optional[str] = None) -> "ProductGraphQLField": + return ProductGraphQLField("metafield", key=key) + + def fields( + self, *subfields: Union[ProductGraphQLField, "MetadataItemFields"] + ) -> "ProductFields": + self._subfields.extend(subfields) + return self + + +class ProductCountableConnectionFields(GraphQLField): + @classmethod + def edges(cls) -> "ProductCountableEdgeFields": + return ProductCountableEdgeFields("edges") + + @classmethod + def page_info(cls) -> "PageInfoFields": + return PageInfoFields("page_info") + + total_count: ProductCountableConnectionGraphQLField = ( + ProductCountableConnectionGraphQLField("totalCount") + ) + + def fields( + self, + *subfields: Union[ + ProductCountableConnectionGraphQLField, + "PageInfoFields", + "ProductCountableEdgeFields", + ] + ) -> "ProductCountableConnectionFields": + self._subfields.extend(subfields) + return self + + +class ProductCountableEdgeFields(GraphQLField): + @classmethod + def node(cls) -> "ProductFields": + return ProductFields("node") + + cursor: ProductCountableEdgeGraphQLField = ProductCountableEdgeGraphQLField( + "cursor" + ) + + def fields( + self, *subfields: Union[ProductCountableEdgeGraphQLField, "ProductFields"] + ) -> "ProductCountableEdgeFields": + self._subfields.extend(subfields) + return self + + +class ProductTranslatableContentFields(GraphQLField): + id: ProductTranslatableContentGraphQLField = ProductTranslatableContentGraphQLField( + "id" + ) + product_id: ProductTranslatableContentGraphQLField = ( + ProductTranslatableContentGraphQLField("productId") + ) + seo_title: ProductTranslatableContentGraphQLField = ( + ProductTranslatableContentGraphQLField("seoTitle") + ) + seo_description: ProductTranslatableContentGraphQLField = ( + ProductTranslatableContentGraphQLField("seoDescription") + ) + name: ProductTranslatableContentGraphQLField = ( + ProductTranslatableContentGraphQLField("name") + ) + description: ProductTranslatableContentGraphQLField = ( + ProductTranslatableContentGraphQLField("description") + ) + + def fields( + self, *subfields: ProductTranslatableContentGraphQLField + ) -> "ProductTranslatableContentFields": + self._subfields.extend(subfields) + return self + + +class ProductTypeCountableConnectionFields(GraphQLField): + @classmethod + def page_info(cls) -> "PageInfoFields": + return PageInfoFields("page_info") + + def fields( + self, + *subfields: Union[ProductTypeCountableConnectionGraphQLField, "PageInfoFields"] + ) -> "ProductTypeCountableConnectionFields": + self._subfields.extend(subfields) + return self + + +class TranslatableItemConnectionFields(GraphQLField): + @classmethod + def page_info(cls) -> "PageInfoFields": + return PageInfoFields("page_info") + + @classmethod + def edges(cls) -> "TranslatableItemEdgeFields": + return TranslatableItemEdgeFields("edges") + + total_count: TranslatableItemConnectionGraphQLField = ( + TranslatableItemConnectionGraphQLField("totalCount") + ) + + def fields( + self, + *subfields: Union[ + TranslatableItemConnectionGraphQLField, + "PageInfoFields", + "TranslatableItemEdgeFields", + ] + ) -> "TranslatableItemConnectionFields": + self._subfields.extend(subfields) + return self + + +class TranslatableItemEdgeFields(GraphQLField): + node: TranslatableItemUnion = TranslatableItemUnion("node") + cursor: TranslatableItemEdgeGraphQLField = TranslatableItemEdgeGraphQLField( + "cursor" + ) + + def fields( + self, + *subfields: Union[TranslatableItemEdgeGraphQLField, "TranslatableItemUnion"] + ) -> "TranslatableItemEdgeFields": + self._subfields.extend(subfields) + return self + + +class UpdateMetadataFields(GraphQLField): + @classmethod + def metadata_errors(cls) -> "MetadataErrorFields": + return MetadataErrorFields("metadata_errors") + + @classmethod + def errors(cls) -> "MetadataErrorFields": + return MetadataErrorFields("errors") + + @classmethod + def item(cls) -> "ObjectWithMetadataInterface": + return ObjectWithMetadataInterface("item") + + def fields( + self, + *subfields: Union[ + UpdateMetadataGraphQLField, + "MetadataErrorFields", + "ObjectWithMetadataInterface", + ] + ) -> "UpdateMetadataFields": + self._subfields.extend(subfields) + return self diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py new file mode 100644 index 00000000..f21dfa8b --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py @@ -0,0 +1,9 @@ +from typing import Any, Optional + +from .custom_fields import UpdateMetadataFields + + +class Mutation: + @classmethod + def update_metadata(cls, *, id: Optional[str] = None) -> UpdateMetadataFields: + return UpdateMetadataFields(field_name="updateMetadata", id=id) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py new file mode 100644 index 00000000..d45f66d1 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py @@ -0,0 +1,43 @@ +from typing import Any, Optional + +from .custom_fields import ( + AppFields, + ProductCountableConnectionFields, + ProductTypeCountableConnectionFields, + TranslatableItemConnectionFields, +) + + +class Query: + @classmethod + def products( + cls, *, channel: Optional[str] = None, first: Optional[int] = None + ) -> ProductCountableConnectionFields: + return ProductCountableConnectionFields( + field_name="products", channel=channel, first=first + ) + + @classmethod + def app(cls) -> AppFields: + return AppFields(field_name="app") + + @classmethod + def product_types(cls) -> ProductTypeCountableConnectionFields: + return ProductTypeCountableConnectionFields(field_name="productTypes") + + @classmethod + def translations( + cls, + *, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None + ) -> TranslatableItemConnectionFields: + return TranslatableItemConnectionFields( + field_name="translations", + before=before, + after=after, + first=first, + last=last, + ) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py new file mode 100644 index 00000000..8d8f7d7d --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py @@ -0,0 +1,63 @@ +from .base_operation import GraphQLField + + +class ProductGraphQLField(GraphQLField): + pass + + +class ProductCountableEdgeGraphQLField(GraphQLField): + pass + + +class ProductCountableConnectionGraphQLField(GraphQLField): + pass + + +class AppGraphQLField(GraphQLField): + pass + + +class ProductTypeCountableConnectionGraphQLField(GraphQLField): + pass + + +class PageInfoGraphQLField(GraphQLField): + pass + + +class ObjectWithMetadataGraphQLField(GraphQLField): + pass + + +class MetadataItemGraphQLField(GraphQLField): + pass + + +class UpdateMetadataGraphQLField(GraphQLField): + pass + + +class MetadataErrorGraphQLField(GraphQLField): + pass + + +class TranslatableItemConnectionGraphQLField(GraphQLField): + pass + + +class TranslatableItemEdgeGraphQLField(GraphQLField): + pass + + +class TranslatableItemUnion(GraphQLField): + def on(self, type_name: str, *subfields: GraphQLField) -> "TranslatableItemUnion": + self._inline_fragments[type_name] = subfields + return self + + +class ProductTranslatableContentGraphQLField(GraphQLField): + pass + + +class CollectionTranslatableContentGraphQLField(GraphQLField): + pass diff --git a/tests/main/clients/custom_query_builder/expected_client/enums.py b/tests/main/clients/custom_query_builder/expected_client/enums.py new file mode 100644 index 00000000..b6a853d5 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/enums.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class MetadataErrorCode(str, Enum): + GRAPHQL_ERROR = "GRAPHQL_ERROR" + INVALID = "INVALID" + NOT_FOUND = "NOT_FOUND" + REQUIRED = "REQUIRED" + NOT_UPDATED = "NOT_UPDATED" diff --git a/tests/main/clients/custom_query_builder/expected_client/exceptions.py b/tests/main/clients/custom_query_builder/expected_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/custom_query_builder/expected_client/input_types.py b/tests/main/clients/custom_query_builder/expected_client/input_types.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/custom_query_builder/pyproject.toml b/tests/main/clients/custom_query_builder/pyproject.toml new file mode 100644 index 00000000..3cbb511e --- /dev/null +++ b/tests/main/clients/custom_query_builder/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +include_comments = "none" +target_package_name = "example_client" +enable_custom_operations = true diff --git a/tests/main/clients/custom_query_builder/schema.graphql b/tests/main/clients/custom_query_builder/schema.graphql new file mode 100644 index 00000000..4e873d7c --- /dev/null +++ b/tests/main/clients/custom_query_builder/schema.graphql @@ -0,0 +1,249 @@ +schema { + query: Query + mutation: Mutation +} + +type Query { + products(channel: String, first: Int): ProductCountableConnection + app: App + productTypes: ProductTypeCountableConnection + translations( + """ + Return the elements in the list that come before the specified cursor. + """ + before: String + + """ + Return the elements in the list that come after the specified cursor. + """ + after: String + + """ + Retrieve the first n elements from the list. Note that the system only allows fetching a maximum of 100 objects in a single query. + """ + first: Int + + """ + Retrieve the last n elements from the list. Note that the system only allows fetching a maximum of 100 objects in a single query. + """ + last: Int + ): TranslatableItemConnection +} + +type Mutation { + updateMetadata( + """ + ID or token (for Order and Checkout) of an object to update. + """ + id: ID! + ): UpdateMetadata +} + +type Product implements ObjectWithMetadata { + id: ID! + slug: String! + name: String! +} + +type ProductCountableEdge { + node: Product! + cursor: String! +} + +type ProductCountableConnection { + edges: [ProductCountableEdge!]! + pageInfo: PageInfo! + totalCount: Int +} + +type App { + id: ID! +} + +type ProductTypeCountableConnection { + pageInfo: PageInfo! +} + +type PageInfo { + hasNextPage: Boolean! + hasPreviousPage: Boolean! + startCursor: String + endCursor: String +} + +interface ObjectWithMetadata { + """ + List of private metadata items. Requires staff permissions to access. + """ + privateMetadata: [MetadataItem!]! + + """ + A single key from private metadata. Requires staff permissions to access. + + Tip: Use GraphQL aliases to fetch multiple keys. + """ + privateMetafield(key: String!): String + + """ + List of public metadata items. Can be accessed without permissions. + """ + metadata: [MetadataItem!]! + + """ + A single key from public metadata. + + Tip: Use GraphQL aliases to fetch multiple keys. + """ + metafield(key: String!): String +} + +type MetadataItem { + """ + Key of a metadata item. + """ + key: String! + + """ + Value of a metadata item. + """ + value: String! +} + +type UpdateMetadata { + metadataErrors: [MetadataError!]! + @deprecated( + reason: "This field will be removed in Saleor 4.0. Use `errors` field instead." + ) + errors: [MetadataError!]! + item: ObjectWithMetadata +} +type MetadataError { + """ + Name of a field that caused the error. A value of `null` indicates that the error isn't associated with a particular field. + """ + field: String + + """ + The error message. + """ + message: String + + """ + The error code. + """ + code: MetadataErrorCode! +} + +""" +An enumeration. +""" +enum MetadataErrorCode { + GRAPHQL_ERROR + INVALID + NOT_FOUND + REQUIRED + NOT_UPDATED +} + +type TranslatableItemConnection { + """ + Pagination data for this connection. + """ + pageInfo: PageInfo! + edges: [TranslatableItemEdge!]! + + """ + A total count of items in the collection. + """ + totalCount: Int +} + +type TranslatableItemEdge { + """ + The item at the end of the edge. + """ + node: TranslatableItem! + + """ + A cursor for use in pagination. + """ + cursor: String! +} + +union TranslatableItem = + ProductTranslatableContent + | CollectionTranslatableContent + +type ProductTranslatableContent @doc(category: "Products") { + """ + The ID of the product translatable content. + """ + id: ID! + + """ + The ID of the product to translate. + + Added in Saleor 3.14. + """ + productId: ID! + + """ + SEO title to translate. + """ + seoTitle: String + + """ + SEO description to translate. + """ + seoDescription: String + + """ + Product's name to translate. + """ + name: String! + + """ + Product's description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString +} + +type CollectionTranslatableContent @doc(category: "Products") { + """ + The ID of the collection translatable content. + """ + id: ID! + + """ + The ID of the collection to translate. + + Added in Saleor 3.14. + """ + collectionId: ID! + + """ + SEO title to translate. + """ + seoTitle: String + + """ + SEO description to translate. + """ + seoDescription: String + + """ + Collection's name to translate. + """ + name: String! + + """ + Collection's description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString +} + +scalar JSONString diff --git a/tests/main/custom_operation_builder/__init__.py b/tests/main/custom_operation_builder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/custom_operation_builder/graphql_client/__init__.py b/tests/main/custom_operation_builder/graphql_client/__init__.py new file mode 100644 index 00000000..9251b3d8 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/__init__.py @@ -0,0 +1,41 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import AutoGenClient +from .custom_typing_fields import ( + AdminGraphQLField, + GuestGraphQLField, + PersonGraphQLField, + PostGraphQLField, + SearchResultUnion, + UserGraphQLField, +) +from .enums import Role +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) +from .input_types import AddUserInput, UpdateUserInput + +__all__ = [ + "AddUserInput", + "AdminGraphQLField", + "AsyncBaseClient", + "AutoGenClient", + "BaseModel", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "GuestGraphQLField", + "PersonGraphQLField", + "PostGraphQLField", + "Role", + "SearchResultUnion", + "UpdateUserInput", + "Upload", + "UserGraphQLField", +] diff --git a/tests/main/custom_operation_builder/graphql_client/async_base_client.py b/tests/main/custom_operation_builder/graphql_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/custom_operation_builder/graphql_client/base_model.py b/tests/main/custom_operation_builder/graphql_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/custom_operation_builder/graphql_client/base_operation.py b/tests/main/custom_operation_builder/graphql_client/base_operation.py new file mode 100644 index 00000000..a488cc73 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/base_operation.py @@ -0,0 +1,92 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +from graphql import ( + ArgumentNode, + BooleanValueNode, + FieldNode, + FloatValueNode, + InlineFragmentNode, + IntValueNode, + NamedTypeNode, + NameNode, + ObjectFieldNode, + ObjectValueNode, + SelectionSetNode, + StringValueNode, +) + +from .base_model import BaseModel + + +class GraphQLArgument: + def __init__(self, argument_name: str, value: Any): + self._name = argument_name + self._value = self._convert_value(value) + + def _convert_value( + self, value: Any + ) -> Union[ + StringValueNode, IntValueNode, FloatValueNode, BooleanValueNode, ObjectValueNode + ]: + if isinstance(value, str): + return StringValueNode(value=value) + if isinstance(value, int): + return IntValueNode(value=str(value)) + if isinstance(value, float): + return FloatValueNode(value=str(value)) + if isinstance(value, bool): + return BooleanValueNode(value=value) + if isinstance(value, BaseModel): + fields = [ + ObjectFieldNode(name=NameNode(value=k), value=self._convert_value(v)) + for k, v in value.model_dump().items() + ] + return ObjectValueNode(fields=fields) + raise TypeError(f"Unsupported argument type: {type(value)}") + + def to_ast(self) -> ArgumentNode: + return ArgumentNode(name=NameNode(value=self._name), value=self._value) + + +class GraphQLField: + def __init__(self, field_name: str, **kwargs: Any) -> None: + self._field_name: str = field_name + self._arguments: List[GraphQLArgument] = [ + GraphQLArgument(k, v) for k, v in kwargs.items() if v + ] + self._subfields: List["GraphQLField"] = [] + self._alias: Optional[str] = None + self._inline_fragments: Dict[str, Tuple["GraphQLField", ...]] = {} + + def alias(self, alias: str) -> "GraphQLField": + self._alias = alias + return self + + def _build_field_name(self) -> str: + if self._alias: + return f"{self._alias}: {self._field_name}" + return self._field_name + + def to_ast(self) -> FieldNode: + selections: List[Union[FieldNode, InlineFragmentNode]] = [ + sub_field.to_ast() for sub_field in self._subfields + ] + if self._inline_fragments: + selections.extend( + [ + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[sub_field.to_ast() for sub_field in subfields] + ), + ) + for name, subfields in self._inline_fragments.items() + ] + ) + return FieldNode( + name=NameNode(value=self._build_field_name()), + arguments=[arg.to_ast() for arg in self._arguments], + selection_set=( + SelectionSetNode(selections=selections) if selections else None + ), + ) diff --git a/tests/main/custom_operation_builder/graphql_client/client.py b/tests/main/custom_operation_builder/graphql_client/client.py new file mode 100644 index 00000000..0dc1ac7c --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/client.py @@ -0,0 +1,52 @@ +from typing import Any, Dict + +from graphql import ( + DocumentNode, + NameNode, + OperationDefinitionNode, + OperationType, + SelectionSetNode, + print_ast, +) + +from .async_base_client import AsyncBaseClient +from .base_operation import GraphQLField + + +def gql(q: str) -> str: + return q + + +class AutoGenClient(AsyncBaseClient): + async def execute_custom_operation( + self, *fields: GraphQLField, operation_type: OperationType, operation_name: str + ) -> Dict[str, Any]: + operation_ast = DocumentNode( + definitions=[ + OperationDefinitionNode( + operation=operation_type, + name=NameNode(value=operation_name), + selection_set=SelectionSetNode( + selections=[field.to_ast() for field in fields] + ), + ) + ] + ) + response = await self.execute( + print_ast(operation_ast), operation_name=operation_name + ) + return self.get_data(response) + + async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, operation_type=OperationType.QUERY, operation_name=operation_name + ) + + async def mutation( + self, *fields: GraphQLField, operation_name: str + ) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, + operation_type=OperationType.MUTATION, + operation_name=operation_name + ) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py new file mode 100644 index 00000000..e0db87bc --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -0,0 +1,100 @@ +from typing import Optional, Union + +from . import ( + AdminGraphQLField, + GuestGraphQLField, + PersonGraphQLField, + PostGraphQLField, + UserGraphQLField, +) +from .base_operation import GraphQLField + + +class AdminFields(GraphQLField): + id: AdminGraphQLField = AdminGraphQLField("id") + name: AdminGraphQLField = AdminGraphQLField("name") + email: AdminGraphQLField = AdminGraphQLField("email") + privileges: AdminGraphQLField = AdminGraphQLField("privileges") + created_at: AdminGraphQLField = AdminGraphQLField("createdAt") + + @classmethod + def metafield(cls, *, key: Optional[str] = None) -> "AdminGraphQLField": + return AdminGraphQLField("metafield", key=key) + + def fields(self, *subfields: AdminGraphQLField) -> "AdminFields": + self._subfields.extend(subfields) + return self + + +class GuestFields(GraphQLField): + id: GuestGraphQLField = GuestGraphQLField("id") + name: GuestGraphQLField = GuestGraphQLField("name") + email: GuestGraphQLField = GuestGraphQLField("email") + visit_count: GuestGraphQLField = GuestGraphQLField("visitCount") + created_at: GuestGraphQLField = GuestGraphQLField("createdAt") + + @classmethod + def metafield(cls, *, key: Optional[str] = None) -> "GuestGraphQLField": + return GuestGraphQLField("metafield", key=key) + + def fields(self, *subfields: GuestGraphQLField) -> "GuestFields": + self._subfields.extend(subfields) + return self + + +class PersonInterface(GraphQLField): + id: PersonGraphQLField = PersonGraphQLField("id") + name: PersonGraphQLField = PersonGraphQLField("name") + email: PersonGraphQLField = PersonGraphQLField("email") + + @classmethod + def metafield(cls, *, key: Optional[str] = None) -> "PersonGraphQLField": + return PersonGraphQLField("metafield", key=key) + + def fields(self, *subfields: PersonGraphQLField) -> "PersonInterface": + self._subfields.extend(subfields) + return self + + def on(self, type_name: str, *subfields: GraphQLField) -> "PersonInterface": + self._inline_fragments[type_name] = subfields + return self + + +class PostFields(GraphQLField): + id: PostGraphQLField = PostGraphQLField("id") + title: PostGraphQLField = PostGraphQLField("title") + content: PostGraphQLField = PostGraphQLField("content") + + @classmethod + def author(cls) -> "PersonInterface": + return PersonInterface("author") + + published_at: PostGraphQLField = PostGraphQLField("publishedAt") + + def fields( + self, *subfields: Union[PostGraphQLField, "PersonInterface"] + ) -> "PostFields": + self._subfields.extend(subfields) + return self + + +class UserFields(GraphQLField): + id: UserGraphQLField = UserGraphQLField("id") + name: UserGraphQLField = UserGraphQLField("name") + email: UserGraphQLField = UserGraphQLField("email") + age: UserGraphQLField = UserGraphQLField("age") + role: UserGraphQLField = UserGraphQLField("role") + + @classmethod + def friends(cls) -> "UserFields": + return UserFields("friends") + + created_at: UserGraphQLField = UserGraphQLField("createdAt") + + @classmethod + def metafield(cls, *, key: Optional[str] = None) -> "UserGraphQLField": + return UserGraphQLField("metafield", key=key) + + def fields(self, *subfields: Union[UserGraphQLField, "UserFields"]) -> "UserFields": + self._subfields.extend(subfields) + return self diff --git a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py new file mode 100644 index 00000000..7bf5894b --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py @@ -0,0 +1,63 @@ +from typing import Any, Optional + +from .custom_fields import PostFields, UserFields +from .input_types import AddUserInput, UpdateUserInput + + +class Mutation: + @classmethod + def add_user(cls, *, user_input: Optional[AddUserInput] = None) -> UserFields: + return UserFields(field_name="addUser", user_input=user_input) + + @classmethod + def update_user( + cls, + *, + user_id: Optional[str] = None, + user_input: Optional[UpdateUserInput] = None + ) -> UserFields: + return UserFields( + field_name="updateUser", user_id=user_id, user_input=user_input + ) + + @classmethod + def delete_user(cls, *, user_id: Optional[str] = None) -> UserFields: + return UserFields(field_name="deleteUser", user_id=user_id) + + @classmethod + def add_post( + cls, + *, + title: Optional[str] = None, + content: Optional[str] = None, + authorId: Optional[str] = None, + publishedAt: Optional[Any] = None + ) -> PostFields: + return PostFields( + field_name="addPost", + title=title, + content=content, + authorId=authorId, + publishedAt=publishedAt, + ) + + @classmethod + def update_post( + cls, + *, + post_id: Optional[str] = None, + title: Optional[str] = None, + content: Optional[str] = None, + publishedAt: Optional[Any] = None + ) -> PostFields: + return PostFields( + field_name="updatePost", + post_id=post_id, + title=title, + content=content, + publishedAt=publishedAt, + ) + + @classmethod + def delete_post(cls, *, post_id: Optional[str] = None) -> PostFields: + return PostFields(field_name="deletePost", post_id=post_id) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_queries.py b/tests/main/custom_operation_builder/graphql_client/custom_queries.py new file mode 100644 index 00000000..dd880ca0 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_queries.py @@ -0,0 +1,64 @@ +from typing import Optional + +from .custom_fields import ( + AdminFields, + GuestFields, + PersonInterface, + PostFields, + UserFields, +) +from .custom_typing_fields import GraphQLField, SearchResultUnion + + +class Query: + @classmethod + def hello(cls) -> GraphQLField: + return GraphQLField(field_name="hello") + + @classmethod + def greeting(cls, *, name: Optional[str] = None) -> GraphQLField: + return GraphQLField(field_name="greeting", name=name) + + @classmethod + def user(cls, *, user_id: Optional[str] = None) -> UserFields: + return UserFields(field_name="user", user_id=user_id) + + @classmethod + def users(cls) -> UserFields: + return UserFields(field_name="users") + + @classmethod + def admin(cls, *, admin_id: Optional[str] = None) -> AdminFields: + return AdminFields(field_name="admin", admin_id=admin_id) + + @classmethod + def admins(cls) -> AdminFields: + return AdminFields(field_name="admins") + + @classmethod + def guest(cls, *, guest_id: Optional[str] = None) -> GuestFields: + return GuestFields(field_name="guest", guest_id=guest_id) + + @classmethod + def guests(cls) -> GuestFields: + return GuestFields(field_name="guests") + + @classmethod + def search(cls, *, text: Optional[str] = None) -> SearchResultUnion: + return SearchResultUnion(field_name="search", text=text) + + @classmethod + def posts(cls) -> PostFields: + return PostFields(field_name="posts") + + @classmethod + def post(cls, *, post_id: Optional[str] = None) -> PostFields: + return PostFields(field_name="post", post_id=post_id) + + @classmethod + def person(cls, *, person_id: Optional[str] = None) -> PersonInterface: + return PersonInterface(field_name="person", person_id=person_id) + + @classmethod + def people(cls) -> PersonInterface: + return PersonInterface(field_name="people") diff --git a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py new file mode 100644 index 00000000..7c78ba55 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py @@ -0,0 +1,27 @@ +from .base_operation import GraphQLField + + +class PersonGraphQLField(GraphQLField): + pass + + +class SearchResultUnion(GraphQLField): + def on(self, type_name: str, *subfields: GraphQLField) -> "SearchResultUnion": + self._inline_fragments[type_name] = subfields + return self + + +class UserGraphQLField(GraphQLField): + pass + + +class AdminGraphQLField(GraphQLField): + pass + + +class GuestGraphQLField(GraphQLField): + pass + + +class PostGraphQLField(GraphQLField): + pass diff --git a/tests/main/custom_operation_builder/graphql_client/enums.py b/tests/main/custom_operation_builder/graphql_client/enums.py new file mode 100644 index 00000000..45e68c85 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/enums.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class Role(str, Enum): + USER = "USER" + ADMIN = "ADMIN" + GUEST = "GUEST" diff --git a/tests/main/custom_operation_builder/graphql_client/exceptions.py b/tests/main/custom_operation_builder/graphql_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/custom_operation_builder/graphql_client/input_types.py b/tests/main/custom_operation_builder/graphql_client/input_types.py new file mode 100644 index 00000000..8c993ec8 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/input_types.py @@ -0,0 +1,22 @@ +from typing import Any, Optional + +from pydantic import Field + +from .base_model import BaseModel +from .enums import Role + + +class AddUserInput(BaseModel): + name: str + age: Optional[int] = None + email: Optional[str] = None + role: Optional[Role] = Role.USER + created_at: Optional[Any] = Field(alias="createdAt", default=None) + + +class UpdateUserInput(BaseModel): + name: Optional[str] = None + age: Optional[int] = None + email: Optional[str] = None + role: Optional[Role] = None + created_at: Optional[Any] = Field(alias="createdAt", default=None) diff --git a/tests/main/custom_operation_builder/test_operation_build.py b/tests/main/custom_operation_builder/test_operation_build.py new file mode 100644 index 00000000..ae3c7ac4 --- /dev/null +++ b/tests/main/custom_operation_builder/test_operation_build.py @@ -0,0 +1,450 @@ +from graphql import print_ast + +from .graphql_client.custom_fields import ( + AdminFields, + GuestFields, + PersonInterface, + PostFields, + UserFields, +) +from .graphql_client.custom_mutations import Mutation +from .graphql_client.custom_queries import Query +from .graphql_client.enums import Role +from .graphql_client.input_types import AddUserInput, UpdateUserInput + + +def test_simple_hello(): + built_query = print_ast(Query.hello().to_ast()) + expected_query = "hello" + assert built_query == expected_query + + +def test_greeting_with_name(): + built_query = print_ast(Query.greeting(name="Alice").to_ast()) + expected_query = 'greeting(name: "Alice")' + assert built_query == expected_query + + +def test_user_by_id(): + built_query = print_ast( + Query.user(user_id="1") + .fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + ) + .to_ast() + ) + expected_query = 'user(user_id: "1") {\n id\n name\n age\n email\n}' + assert built_query == expected_query + + +def test_all_users(): + built_query = print_ast( + Query.users() + .fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + ) + .to_ast() + ) + expected_query = "users {\n id\n name\n age\n email\n}" + assert built_query == expected_query + + +def test_user_with_friends(): + built_query = print_ast( + Query.user(user_id="1") + .fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.friends().fields( + UserFields.id, + UserFields.name, + ), + UserFields.created_at, + ) + .to_ast() + ) + expected_query = ( + 'user(user_id: "1") {\n' + " id\n" + " name\n" + " age\n" + " email\n" + " friends {\n" + " id\n" + " name\n" + " }\n" + " createdAt\n" + "}" + ) + assert built_query == expected_query + + +def test_search_example(): + built_query = print_ast( + Query.search(text="example") + .on( + "User", + UserFields.id, + UserFields.name, + UserFields.email, + UserFields.created_at, + ) + .on( + "Admin", + AdminFields.id, + AdminFields.name, + AdminFields.privileges, + AdminFields.created_at, + ) + .on( + "Guest", + GuestFields.id, + GuestFields.name, + GuestFields.visit_count, + GuestFields.created_at, + ) + .to_ast() + ) + expected_query = ( + 'search(text: "example") {\n' + " ... on User {\n" + " id\n" + " name\n" + " email\n" + " createdAt\n" + " }\n" + " ... on Admin {\n" + " id\n" + " name\n" + " privileges\n" + " createdAt\n" + " }\n" + " ... on Guest {\n" + " id\n" + " name\n" + " visitCount\n" + " createdAt\n" + " }\n" + "}" + ) + assert built_query == expected_query + + +def test_posts_with_authors(): + built_query = print_ast( + Query.posts() + .fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.author().fields( + PersonInterface.id, PersonInterface.name, PersonInterface.email + ), + PostFields.published_at, + ) + .to_ast() + ) + expected_query = ( + "posts {\n" + " id\n" + " title\n" + " content\n" + " author {\n" + " id\n" + " name\n" + " email\n" + " }\n" + " publishedAt\n" + "}" + ) + assert built_query == expected_query + + +def test_get_person(): + built_query = print_ast( + Query.person(person_id="1") + .fields(PersonInterface.id, PersonInterface.name, PersonInterface.email) + .on("User", UserFields.age, UserFields.role) + .on("Admin", AdminFields.privileges) + .to_ast() + ) + expected_query = ( + 'person(person_id: "1") {\n' + " id\n" + " name\n" + " email\n" + " ... on User {\n" + " age\n" + " role\n" + " }\n" + " ... on Admin {\n" + " privileges\n" + " }\n" + "}" + ) + assert built_query == expected_query + + +def test_get_people(): + built_query = print_ast( + Query.people() + .fields(PersonInterface.id, PersonInterface.name, PersonInterface.email) + .on("User", UserFields.age, UserFields.role) + .on("Admin", AdminFields.privileges) + .to_ast() + ) + expected_query = ( + "people {\n" + " id\n" + " name\n" + " email\n" + " ... on User {\n" + " age\n" + " role\n" + " }\n" + " ... on Admin {\n" + " privileges\n" + " }\n" + "}" + ) + assert built_query == expected_query + + +def test_add_user_mutation(): + built_mutation = print_ast( + Mutation.add_user( + user_input=AddUserInput( + name="bob", + age=30, + email="bob@example.com", + role=Role.ADMIN, + createdAt="2024-06-07T00:00:00.000Z", + ) + ) + .fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.role, + UserFields.created_at, + ) + .to_ast() + ) + expected_mutation = ( + "addUser(\n" + ' user_input: {name: "bob", age: 30, email: "bob@example.com", role: "ADMIN", ' + 'created_at: "2024-06-07T00:00:00.000Z"}\n' + ") {\n" + " id\n" + " name\n" + " age\n" + " email\n" + " role\n" + " createdAt\n" + "}" + ) + assert built_mutation == expected_mutation + + +def test_update_user_mutation(): + built_mutation = print_ast( + Mutation.update_user( + user_id="1", + user_input=UpdateUserInput( + name="Alice Updated", + age=25, + email="alice.updated@example.com", + role=Role.USER, + createdAt="2024-06-07T00:00:00.000Z", + ), + ) + .fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.role, + UserFields.created_at, + ) + .to_ast() + ) + expected_mutation = ( + "updateUser(\n" + ' user_id: "1"\n' + " user_input: " + '{name: "Alice Updated", age: 25, email: "alice.updated@example.com", ' + 'role: "USER", created_at: "2024-06-07T00:00:00.000Z"}\n' + ") {\n" + " id\n" + " name\n" + " age\n" + " email\n" + " role\n" + " createdAt\n" + "}" + ) + assert built_mutation == expected_mutation + + +def test_delete_user_mutation(): + built_mutation = print_ast( + Mutation.delete_user(user_id="1") + .fields( + UserFields.id, + UserFields.name, + ) + .to_ast() + ) + expected_mutation = 'deleteUser(user_id: "1") {\n id\n name\n}' + assert built_mutation == expected_mutation + + +def test_add_post_mutation(): + built_mutation = print_ast( + Mutation.add_post( + title="New Post", + content="This is the content", + authorId="1", + publishedAt="2024-06-07T00:00:00.000Z", + ) + .fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.author().fields(PersonInterface.id, PersonInterface.name), + PostFields.published_at, + ) + .to_ast() + ) + expected_mutation = ( + "addPost(\n" + ' title: "New Post"\n' + ' content: "This is the content"\n' + ' authorId: "1"\n' + ' publishedAt: "2024-06-07T00:00:00.000Z"\n' + ") {\n" + " id\n" + " title\n" + " content\n" + " author {\n" + " id\n" + " name\n" + " }\n" + " publishedAt\n" + "}" + ) + assert built_mutation == expected_mutation + + +def test_update_post_mutation(): + built_mutation = print_ast( + Mutation.update_post( + post_id="1", + title="Updated Title", + content="Updated Content", + publishedAt="2024-06-07T00:00:00.000Z", + ) + .fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.published_at, + ) + .to_ast() + ) + expected_mutation = ( + "updatePost(\n" + ' post_id: "1"\n' + ' title: "Updated Title"\n' + ' content: "Updated Content"\n' + ' publishedAt: "2024-06-07T00:00:00.000Z"\n' + ") {\n" + " id\n" + " title\n" + " content\n" + " publishedAt\n" + "}" + ) + assert built_mutation == expected_mutation + + +def test_delete_post_mutation(): + built_mutation = print_ast( + Mutation.delete_post(post_id="1") + .fields( + PostFields.id, + PostFields.title, + ) + .to_ast() + ) + expected_mutation = 'deletePost(post_id: "1") {\n id\n title\n}' + assert built_mutation == expected_mutation + + +def test_user_specific_fields(): + built_query = print_ast( + Query.user(user_id="1").fields(UserFields.id, UserFields.name).to_ast() + ) + expected_query = 'user(user_id: "1") {\n id\n name\n}' + assert built_query == expected_query + + +def test_user_with_friends_specific_fields(): + built_query = print_ast( + Query.user(user_id="1") + .fields( + UserFields.id, + UserFields.name, + UserFields.friends().fields(UserFields.id, UserFields.name), + UserFields.created_at, + ) + .to_ast() + ) + expected_query = ( + 'user(user_id: "1") {\n' + " id\n" + " name\n" + " friends {\n" + " id\n" + " name\n" + " }\n" + " createdAt\n" + "}" + ) + assert built_query == expected_query + + +def test_people_with_metadata(): + built_query = print_ast( + Query.people() + .fields( + PersonInterface.id, + PersonInterface.name, + PersonInterface.email, + PersonInterface.metafield(key="bio"), + ) + .on("User", UserFields.age, UserFields.role) + .to_ast() + ) + expected_query = ( + "people {\n" + " id\n" + " name\n" + " email\n" + ' metafield(key: "bio")\n' + " ... on User {\n" + " age\n" + " role\n" + " }\n" + "}" + ) + assert built_query == expected_query diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 9b94f1a7..2182a1de 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -197,6 +197,14 @@ def test_main_shows_version(): "interface_as_fragment", CLIENTS_PATH / "interface_as_fragment" / "expected_client", ), + ( + ( + CLIENTS_PATH / "custom_query_builder" / "pyproject.toml", + (CLIENTS_PATH / "custom_query_builder" / "schema.graphql",), + ), + "example_client", + CLIENTS_PATH / "custom_query_builder" / "expected_client", + ), ], indirect=["project_dir"], ) From 8ca600406cb2a5574639cd13f4254ad5a3912b5e Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Tue, 11 Jun 2024 14:32:18 +0200 Subject: [PATCH 02/11] Update version and changelog --- CHANGELOG.md | 1 + pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4d7f2cb..7b9b9f9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Re-added `model_rebuild` calls for input types with forward references. - Fixed fragments on interfaces being omitted from generated client. - Fixed `@Include` directive result type when using `convert_to_snake_case` option. +- Added Custom query builder feature. ## 0.13.0 (2024-03-4) diff --git a/pyproject.toml b/pyproject.toml index 89f8bb40..81ab64dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "ariadne-codegen" description = "Generate fully typed GraphQL client from schema, queries and mutations!" authors = [{ name = "Mirumee Software", email = "hello@mirumee.com" }] -version = "0.13.0" +version = "0.14.0.dev1" readme = "README.md" license = { file = "LICENSE" } classifiers = [ From b330edd5561833eedba9e269dce65a7d8b35410c Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Tue, 9 Jul 2024 11:34:01 +0200 Subject: [PATCH 03/11] Fix scalar, enum import. Add support for custom scalars. Add basic support for plugins --- ariadne_codegen/client_generators/client.py | 500 ++++++++++++++---- .../client_generators/constants.py | 4 + .../client_generators/custom_fields.py | 124 ++++- .../client_generators/custom_operation.py | 136 ++++- .../dependencies/base_operation.py | 104 ++-- ariadne_codegen/client_generators/package.py | 16 +- pyproject.toml | 2 +- .../expected_client/base_operation.py | 104 ++-- .../expected_client/client.py | 60 ++- .../expected_client/custom_fields.py | 56 +- .../expected_client/custom_mutations.py | 6 +- .../expected_client/custom_queries.py | 22 +- 12 files changed, 827 insertions(+), 307 deletions(-) diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 5b245463..7bc42a91 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -16,6 +16,7 @@ generate_class_def, generate_comp, generate_constant, + generate_dict, generate_expr, generate_import_from, generate_keyword, @@ -44,15 +45,19 @@ LIST, MODEL_VALIDATE_METHOD, NAME_NODE, + NAMED_TYPE_NODE, OPERATION_DEFINITION_NODE, OPERATION_TYPE, OPTIONAL, PRINT_AST, SELECTION_SET_NODE, + TUPLE, TYPING_MODULE, UNION, UNSET_IMPORT, UPLOAD_IMPORT, + VARIABLE_DEFINITION_NODE, + VARIABLE_NODE, ) from .scalars import ScalarData, generate_scalar_imports @@ -207,133 +212,444 @@ def add_method( generate_import_from(names=[return_type], from_=return_type_module, level=1) ) - def add_execute_custom_operation_method(self): - self._add_import( - generate_import_from( - [ - DOCUMENT_NODE, - OPERATION_DEFINITION_NODE, - NAME_NODE, - SELECTION_SET_NODE, - PRINT_AST, + def create_combine_variables_method(self): + method_body = [ + generate_assign( + targets=["variables_types_combined"], + value=generate_dict(), + ), + generate_assign( + targets=["processed_variables_combined"], + value=generate_dict(), + ), + ast.For( + target=generate_tuple( + elts=[ + generate_name("idx"), + generate_name("field"), + ], + ), + iter=generate_call( + func=generate_name("enumerate"), + args=[generate_name("fields")], + ), + body=[ + generate_expr( + value=generate_call( + func=generate_attribute( + value=generate_name("variables_types_combined"), + attr="update", + ), + args=[ + generate_call( + func=generate_attribute( + value=generate_name("field"), + attr="get_variables_types", + ), + args=[generate_name("idx")], + ) + ], + ) + ), + generate_expr( + value=generate_call( + func=generate_attribute( + value=generate_name("processed_variables_combined"), + attr="update", + ), + args=[ + generate_call( + func=generate_attribute( + value=generate_name("field"), + attr="get_processed_variables", + ), + args=[generate_name("idx")], + ) + ], + ) + ), ], - GRAPHQL_MODULE, - ) + orelse=[], + lineno=1, + ), + generate_return( + value=generate_tuple( + elts=[ + generate_name("variables_types_combined"), + generate_name("processed_variables_combined"), + ], + ) + ), + ] + + args = generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + name="fields", + annotation=generate_subscript( + generate_name(TUPLE), + generate_tuple( + [ + generate_name("GraphQLField"), + generate_name("..."), + ] + ), + ), + ), + ], ) - self._add_import( - generate_import_from( - [BASE_GRAPHQL_FIELD_CLASS_NAME], BASE_OPERATION_FILE_PATH.stem, level=1 - ) + + returns = generate_subscript( + generate_name(TUPLE), + generate_tuple( + [ + generate_subscript( + generate_name(DICT), + generate_tuple([generate_name("str"), generate_name("Any")]), + ), + generate_subscript( + generate_name(DICT), + generate_tuple([generate_name("str"), generate_name("Any")]), + ), + ] + ), ) - execute_await = generate_await( - value=generate_call( - func=generate_attribute(value=generate_name("self"), attr="execute"), - args=[ - generate_call( - func=generate_name("print_ast"), - args=[generate_name("operation_ast")], - ) - ], - keywords=[ - generate_keyword( - arg="operation_name", value=generate_name("operation_name") - ) - ], - ) + + method_def = generate_method_definition( + name="_combine_variables", + arguments=args, + body=method_body, + decorator_list=[], + return_type=returns, ) - operation_definition_node = generate_call( - func=generate_name("OperationDefinitionNode"), - keywords=[ - generate_keyword( - arg="operation", value=generate_name("operation_type") - ), - generate_keyword( - arg="name", - value=generate_call( - func=generate_name("NameNode"), + return method_def + + def create_build_variable_definitions_method(self): + method_body = [ + generate_return( + value=generate_list_comp( + elt=generate_call( + func=generate_name("VariableDefinitionNode"), keywords=[ generate_keyword( - arg="value", value=generate_name("operation_name") - ) - ], - ), - ), - generate_keyword( - arg="selection_set", - value=generate_call( - func=generate_name("SelectionSetNode"), - keywords=[ + arg="variable", + value=generate_call( + func=generate_name("VariableNode"), + keywords=[ + generate_keyword( + arg="name", + value=generate_call( + func=generate_name("NameNode"), + keywords=[ + generate_keyword( + arg="value", + value=generate_name("var_name"), + ) + ], + ), + ) + ], + ), + ), generate_keyword( - arg="selections", - value=generate_list_comp( - elt=generate_call( - func=generate_attribute( - value=generate_name("field"), - attr="to_ast", - ), - ), - generators=[ - generate_comp( - target="field", - iter_="fields", + arg="type", + value=generate_call( + func=generate_name("NamedTypeNode"), + keywords=[ + generate_keyword( + arg="name", + value=generate_call( + func=generate_name("NameNode"), + keywords=[ + generate_keyword( + arg="value", + value=generate_name( + "var_value", + ), + ) + ], + ), ) ], ), - ) + ), ], ), - ), - ], - ) - operation_ast = generate_call( - func=generate_name("DocumentNode"), - keywords=[ - generate_keyword( - arg="definitions", - value=generate_list(elements=[operation_definition_node]), + generators=[ + generate_comp( + target="var_name, var_value", + iter_="variables_types_combined.items()", + ) + ], ) - ], - ) - body_return = generate_return( - value=generate_call( - func=generate_attribute(value=generate_name("self"), attr="get_data"), - args=[generate_name("response")], ) + ] + return generate_method_definition( + name="_build_variable_definitions", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + "variables_types_combined", + annotation=generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_name("str"), + ] + ), + ), + ), + ] + ), + body=method_body, + return_type=generate_subscript( + generate_name("List"), generate_name("VariableDefinitionNode") + ), ) - async_def_node = generate_async_method_definition( - name="execute_custom_operation", + + def create_build_operation_ast_method(self): + keywords = [ + generate_keyword( + arg="definitions", + value=generate_list( + elements=[ + generate_call( + func=generate_name("OperationDefinitionNode"), + keywords=[ + generate_keyword( + arg="operation", + value=generate_name("operation_type"), + ), + generate_keyword( + arg="name", + value=generate_call( + func=generate_name("NameNode"), + keywords=[ + generate_keyword( + arg="value", + value=generate_name( + "operation_name", + ), + ) + ], + ), + ), + generate_keyword( + arg="variable_definitions", + value=generate_name( + "variable_definitions", + ), + ), + generate_keyword( + arg="selection_set", + value=generate_call( + func=generate_name( + "SelectionSetNode", + ), + keywords=[ + generate_keyword( + arg="selections", + value=generate_list_comp( + elt=generate_call( + func=generate_attribute( + value=generate_name( + "field", + ), + attr="to_ast", + ), + args=[generate_name("idx")], + ), + generators=[ + generate_comp( + target="idx, field", + iter_="enumerate(fields)", + ) + ], + ), + ) + ], + ), + ), + ], + ) + ] + ), + ) + ] + method_body = [ + generate_return( + value=generate_call( + func=generate_name("DocumentNode"), + keywords=keywords, + ) + ) + ] + return generate_method_definition( + name="_build_operation_ast", arguments=generate_arguments( args=[ generate_arg("self"), generate_arg( - "*fields", - annotation=generate_name("GraphQLField"), + "fields", + annotation=generate_subscript( + generate_name(TUPLE), + generate_tuple( + [ + generate_name("GraphQLField"), + generate_name("..."), + ] + ), + ), ), generate_arg( "operation_type", annotation=generate_name("OperationType"), ), generate_arg("operation_name", annotation=generate_name("str")), + generate_arg( + "variable_definitions", + annotation=generate_subscript( + generate_name("List"), + generate_name("VariableDefinitionNode"), + ), + ), + ] + ), + body=method_body, + return_type=generate_name("DocumentNode"), + ) + + def create_execute_custom_operation_method(self): + variables_types_combined = generate_name("variables_types_combined") + processed_variables_combined = generate_name("processed_variables_combined") + method_body = [ + ast.Assign( + targets=[ + generate_tuple( + elts=[variables_types_combined, processed_variables_combined], + ) ], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_combine_variables" + ), + args=[generate_name("fields")], + ), + lineno=1, ), - body=[ - generate_assign( - targets=["operation_ast"], - value=operation_ast, + generate_assign( + targets=["variable_definitions"], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_build_variable_definitions" + ), + args=[generate_name("variables_types_combined")], ), - generate_assign( - targets=["response"], - value=execute_await, + ), + generate_assign( + targets=["operation_ast"], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_build_operation_ast" + ), + args=[ + generate_name("fields"), + generate_name("operation_type"), + generate_name("operation_name"), + generate_name("variable_definitions"), + ], ), - body_return, - ], + ), + generate_assign( + targets=["response"], + value=generate_await( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), + attr="execute", + ), + args=[ + generate_call( + func=generate_name("print_ast"), + args=[generate_name("operation_ast")], + ) + ], + keywords=[ + generate_keyword( + arg="variables", + value=generate_name("processed_variables_combined"), + ), + generate_keyword( + arg="operation_name", + value=generate_name("operation_name"), + ), + ], + ) + ), + ), + generate_return( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), + attr="get_data", + ), + args=[generate_name("response")], + ) + ), + ] + return generate_async_method_definition( + name="execute_custom_operation", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg("*fields", annotation=generate_name("GraphQLField")), + generate_arg( + "operation_type", + annotation=generate_name( + "OperationType", + ), + ), + generate_arg("operation_name", annotation=generate_name("str")), + ] + ), + body=method_body, return_type=generate_subscript( generate_name(DICT), generate_tuple([generate_name("str"), generate_name("Any")]), ), ) - self._class_def.body.append(async_def_node) + + def add_execute_custom_operation_method(self): + self._add_import( + generate_import_from( + [ + DOCUMENT_NODE, + OPERATION_DEFINITION_NODE, + NAME_NODE, + SELECTION_SET_NODE, + PRINT_AST, + VARIABLE_DEFINITION_NODE, + VARIABLE_NODE, + NAMED_TYPE_NODE, + ], + GRAPHQL_MODULE, + ) + ) + self._add_import( + generate_import_from( + [BASE_GRAPHQL_FIELD_CLASS_NAME], BASE_OPERATION_FILE_PATH.stem, level=1 + ) + ) + self._add_import(generate_import_from([DICT, TUPLE, LIST, ANY], "typing")) + + self._class_def.body.append(self.create_execute_custom_operation_method()) + self._class_def.body.append(self.create_combine_variables_method()) + self._class_def.body.append(self.create_build_variable_definitions_method()) + self._class_def.body.append(self.create_build_operation_ast_method()) def create_custom_operation_method(self, name, operation_type): self._add_import( diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index f9927e6b..a3258def 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -19,6 +19,7 @@ TYPE = "Type" TYPE_CHECKING = "TYPE_CHECKING" DICT = "Dict" +TUPLE = "Tuple" CALLABLE = "Callable" ANNOTATED = "Annotated" LITERAL = "Literal" @@ -29,6 +30,9 @@ SELECTION_SET_NODE = "SelectionSetNode" PRINT_AST = "print_ast" OPERATION_TYPE = "OperationType" +VARIABLE_DEFINITION_NODE = "VariableDefinitionNode" +VARIABLE_NODE = "VariableNode" +NAMED_TYPE_NODE = "NamedTypeNode" HTTPX = "httpx" HTTPX_RESPONSE = "httpx.Response" diff --git a/ariadne_codegen/client_generators/custom_fields.py b/ariadne_codegen/client_generators/custom_fields.py index b223b40c..7b80a62a 100644 --- a/ariadne_codegen/client_generators/custom_fields.py +++ b/ariadne_codegen/client_generators/custom_fields.py @@ -1,17 +1,23 @@ import ast -from typing import Dict, List, Optional, Set, Union, cast +from typing import Dict, List, Optional, Set, Tuple, Union, cast from graphql import ( GraphQLEnumType, GraphQLInputObjectType, GraphQLInterfaceType, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, GraphQLSchema, GraphQLUnionType, ) +from ariadne_codegen.client_generators.scalars import ( + ScalarData, + generate_scalar_imports, +) from ariadne_codegen.exceptions import ParsingError +from ariadne_codegen.plugins.manager import PluginManager from ..codegen import ( generate_ann_assign, @@ -22,6 +28,7 @@ generate_call, generate_class_def, generate_constant, + generate_dict, generate_expr, generate_import_from, generate_keyword, @@ -51,10 +58,12 @@ def __init__( self, schema: GraphQLSchema, convert_to_snake_case: bool = True, - custom_scalars=None, + custom_scalars: Optional[Dict[str, ScalarData]] = None, + plugin_manager: Optional[PluginManager] = None, ) -> None: self.schema = schema self.convert_to_snake_case = convert_to_snake_case + self.plugin_manager = plugin_manager self.custom_scalars = custom_scalars if custom_scalars else {} self._visited_types: Set[str] = set() self._field_classes: Set[str] = set() @@ -66,7 +75,8 @@ def __init__( level=1, ) ] - self._add_import(generate_import_from([OPTIONAL, UNION], TYPING_MODULE)) + self._used_custom_scalars: List[str] = [] + self._add_import(generate_import_from([OPTIONAL, UNION, ANY], TYPING_MODULE)) self._class_defs: List[ast.ClassDef] = self._parse_object_type_definitions( TypeCollector(self.schema).collect() @@ -75,11 +85,13 @@ def __init__( def _add_import(self, import_: Optional[ast.ImportFrom] = None): if not import_: return - + if self.plugin_manager: + import_ = self.plugin_manager.generate_client_import(import_) if import_.names: self._imports.append(import_) def generate(self) -> ast.Module: + self._add_custom_scalar_imports() module = generate_module( body=( cast(List[ast.stmt], self._imports) @@ -224,39 +236,91 @@ def _generate_fields_method( return_type=generate_name(f'"{class_name}"'), ) + def _generate_kw_args_and_defaults(self, operation_args): + kw_only_args = [] + kw_defaults = [] + args = [] + for arg_name, arg_type in operation_args.items(): + arg_name = process_name( + arg_name, + convert_to_snake_case=self.convert_to_snake_case, + ) + arg_final_type = get_final_type(arg_type) + is_required = isinstance(arg_type.type, GraphQLNonNull) + annotation, _ = self._parse_graphql_type_name( + arg_final_type, + not is_required, + ) + arg = generate_arg(name=arg_name, annotation=annotation) + if is_required: + args.append(arg) + else: + kw_only_args.append(arg) + kw_defaults.append(generate_constant(value=None)) + return kw_only_args, kw_defaults, args + + def _get_dict_value(self, name: str, arg_value) -> Union[ast.Name, ast.Call]: + name = process_name( + name, + convert_to_snake_case=self.convert_to_snake_case, + ) + _, used_custom_scalar = self._parse_graphql_type_name(get_final_type(arg_value)) + if used_custom_scalar: + self._used_custom_scalars.append(used_custom_scalar) + scalar_data = self.custom_scalars[used_custom_scalar] + if scalar_data.serialize_name: + return generate_call( + func=generate_name(scalar_data.serialize_name), + args=[generate_name(name)], + ) + return generate_name(name) + + def _generate_arguments_dict(self, operation_args) -> Dict[ast.Constant, ast.Dict]: + arguments_dict = {} + for arg_name, arg_value in operation_args.items(): + final_type = get_final_type(arg_value) + is_required = isinstance(arg_value.type, GraphQLNonNull) + constant_value = f"{final_type.name}!" if is_required else final_type.name + arguments_dict[generate_constant(arg_name)] = generate_dict( + keys=[generate_constant("type"), generate_constant("value")], + values=[ + generate_constant(constant_value), + self._get_dict_value(arg_name, arg_value), + ], + ) + return arguments_dict + def generate_product_type_method( self, name, class_name, arguments=None ) -> ast.FunctionDef: arguments = arguments or {} - return_keywords = [] field_class_name = generate_name(class_name) - field_kwonlyargs = [] - field_kw_defaults: List[Union[ast.expr, None]] = [] - for arg_name, argument in arguments.items(): - argument_final_type = get_final_type(argument.type) - field_kwonlyargs.append( - generate_arg( - name=arg_name, - annotation=self._parse_graphql_type_name(argument_final_type), - ) - ) - field_kw_defaults.append(generate_constant(value=None)) - return_keywords.append( - generate_keyword(arg=arg_name, value=generate_name(arg_name)) - ) + kw_only_args, kw_defaults, args = self._generate_kw_args_and_defaults( + arguments, + ) + return_arguments_dict = self._generate_arguments_dict(arguments) + + return_keyword = generate_keyword( + arg="arguments", + value=generate_dict( + keys=list(return_arguments_dict.keys()), + values=list(return_arguments_dict.values()), + ), + ) + return generate_method_definition( name, arguments=generate_arguments( - args=[generate_arg(name="cls")], - kwonlyargs=field_kwonlyargs, - kw_defaults=field_kw_defaults, + args=[generate_arg(name="cls"), *args], + kwonlyargs=kw_only_args, + kw_defaults=kw_defaults, ), body=[ generate_return( value=generate_call( func=field_class_name, args=[generate_constant(name)], - keywords=return_keywords, + keywords=[return_keyword], ) ), ], @@ -300,9 +364,9 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef: def _parse_graphql_type_name( self, type_, nullable: bool = True - ) -> Union[ast.Name, ast.Subscript]: + ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: name = type_.name - + used_custom_scalar = None if isinstance(type_, GraphQLInputObjectType): self._add_import( generate_import_from(names=[name], from_="input_types", level=1) @@ -321,8 +385,16 @@ def _parse_graphql_type_name( ) ) else: + used_custom_scalar = name name = self.custom_scalars[name].type_name + self._used_custom_scalars.append(used_custom_scalar) else: raise ParsingError(f"Incorrect argument type {name}") - return generate_annotation_name(name, nullable) + return generate_annotation_name(name, nullable), used_custom_scalar + + def _add_custom_scalar_imports(self): + for custom_scalar_name in self._used_custom_scalars: + scalar_data = self.custom_scalars[custom_scalar_name] + for import_ in generate_scalar_imports(scalar_data): + self._add_import(import_) diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py index c0d9b2e7..d0545ad6 100644 --- a/ariadne_codegen/client_generators/custom_operation.py +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -6,14 +6,12 @@ GraphQLFieldMap, GraphQLInputObjectType, GraphQLInterfaceType, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, GraphQLUnionType, ) -from ariadne_codegen.exceptions import ParsingError -from ariadne_codegen.utils import str_to_snake_case - from ..codegen import ( generate_annotation_name, generate_arg, @@ -21,6 +19,7 @@ generate_call, generate_class_def, generate_constant, + generate_dict, generate_import_from, generate_keyword, generate_method_definition, @@ -28,7 +27,10 @@ generate_name, generate_return, ) +from ..exceptions import ParsingError from ..plugins.manager import PluginManager +from ..utils import process_name, str_to_snake_case +from .arguments import ArgumentsGenerator from .constants import ( ANY, BASE_MODEL_FILE_PATH, @@ -39,7 +41,7 @@ TYPING_MODULE, UPLOAD_CLASS_NAME, ) -from .scalars import ScalarData +from .scalars import ScalarData, generate_scalar_imports from .utils import get_final_type @@ -49,9 +51,11 @@ def __init__( graphql_fields: GraphQLFieldMap, name: str, base_name: str, + arguments_generator: ArgumentsGenerator, enums_module_name: str = "enums", custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, + convert_to_snake_case: bool = True, ) -> None: self.graphql_fields = graphql_fields self.name = name @@ -59,6 +63,9 @@ def __init__( self.enums_module_name = enums_module_name self.plugin_manager = plugin_manager self.custom_scalars = custom_scalars if custom_scalars else {} + self._used_custom_scalars: List[str] = [] + self.arguments_generator = arguments_generator + self.convert_to_snake_case = convert_to_snake_case self._imports: List[ast.ImportFrom] = [] self._type_imports: List[ast.ImportFrom] = [] @@ -70,10 +77,8 @@ def __init__( def generate(self) -> ast.Module: """Generate module with class definition of graphql client.""" - for name, field in self.graphql_fields.items(): final_type = get_final_type(field) - # if isinstance(final_type, GraphQLObjectType): method_def = self._generate_method( operation_name=name, operation_args=field.args, @@ -85,6 +90,8 @@ def generate(self) -> ast.Module: if not self._class_def.body: self._class_def.body.append(ast.Pass()) + self._add_custom_scalar_imports() + self._class_def.lineno = len(self._imports) + 3 module = generate_module( @@ -108,17 +115,8 @@ def _generate_method( final_type, ) -> ast.FunctionDef: arguments = self._generate_method_arguments(operation_args) - from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem - if isinstance(final_type, GraphQLObjectType): - return_type_name = f"{final_type.name}Fields" - from_ = CUSTOM_FIELDS_FILE_PATH.stem - elif isinstance(final_type, GraphQLInterfaceType): - return_type_name = f"{final_type.name}Interface" - from_ = CUSTOM_FIELDS_FILE_PATH.stem - elif isinstance(final_type, GraphQLUnionType): - return_type_name = f"{final_type.name}Union" - else: - return_type_name = "GraphQLField" + return_type_name, from_ = self._get_return_type_and_from(final_type) + self._type_imports.append( generate_import_from( from_=from_, @@ -143,9 +141,11 @@ def _generate_method( def _generate_method_arguments(self, operation_args): cls_arg = generate_arg(name="cls") - kw_only_args, kw_defaults = self._generate_kw_args_and_defaults(operation_args) + kw_only_args, kw_defaults, args = self._generate_kw_args_and_defaults( + operation_args, + ) return generate_arguments( - args=[cls_arg], + args=[cls_arg, *args], kwonlyargs=kw_only_args, kw_defaults=kw_defaults, ) @@ -153,18 +153,37 @@ def _generate_method_arguments(self, operation_args): def _generate_kw_args_and_defaults(self, operation_args): kw_only_args = [] kw_defaults = [] + args = [] for arg_name, arg_type in operation_args.items(): + arg_name = process_name( + arg_name, + convert_to_snake_case=self.convert_to_snake_case, + ) arg_final_type = get_final_type(arg_type) - annotation, _ = self._parse_graphql_type_name(arg_final_type) - kw_only_args.append(generate_arg(name=arg_name, annotation=annotation)) - kw_defaults.append(generate_constant(value=None)) - return kw_only_args, kw_defaults + is_required = isinstance(arg_type.type, GraphQLNonNull) + annotation, _ = self._parse_graphql_type_name( + arg_final_type, + not is_required, + ) + arg = generate_arg(name=arg_name, annotation=annotation) + if is_required: + args.append(arg) + else: + kw_only_args.append(arg) + kw_defaults.append(generate_constant(value=None)) + return kw_only_args, kw_defaults, args def _generate_return_stmt(self, return_type_name, operation_name, operation_args): - keywords = [ - generate_keyword(arg=arg_name, value=generate_name(arg_name)) - for arg_name in operation_args - ] + arguments_dict = self._generate_arguments_dict(operation_args) + + arguments_keyword = generate_keyword( + arg="arguments", + value=generate_dict( + keys=arguments_dict.keys(), + values=arguments_dict.values(), + ), + ) + return generate_return( value=generate_call( func=generate_name(return_type_name), @@ -173,11 +192,42 @@ def _generate_return_stmt(self, return_type_name, operation_name, operation_args generate_keyword( arg="field_name", value=generate_constant(value=operation_name) ), - *keywords, + arguments_keyword, ], ) ) + def _generate_arguments_dict(self, operation_args): + arguments_dict = {} + for arg_name, arg_value in operation_args.items(): + final_type = get_final_type(arg_value) + is_required = isinstance(arg_value.type, GraphQLNonNull) + constant_value = f"{final_type.name}!" if is_required else final_type.name + arguments_dict[generate_constant(arg_name)] = generate_dict( + keys=[generate_constant("type"), generate_constant("value")], + values=[ + generate_constant(constant_value), + self._get_dict_value(arg_name, arg_value), + ], + ) + return arguments_dict + + def _get_dict_value(self, name: str, arg_value) -> Union[ast.Name, ast.Call]: + name = process_name( + name, + convert_to_snake_case=self.convert_to_snake_case, + ) + _, used_custom_scalar = self._parse_graphql_type_name(get_final_type(arg_value)) + if used_custom_scalar: + self._used_custom_scalars.append(used_custom_scalar) + scalar_data = self.custom_scalars[used_custom_scalar] + if scalar_data.serialize_name: + return generate_call( + func=generate_name(scalar_data.serialize_name), + args=[generate_name(name)], + ) + return generate_name(name) + def _parse_graphql_type_name( self, type_, nullable: bool = True ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: @@ -194,7 +244,13 @@ def _parse_graphql_type_name( ) ) elif isinstance(type_, GraphQLEnumType): - self._add_import(generate_import_from(names=[name], level=1)) + self._add_import( + generate_import_from( + names=[name], + from_=self.enums_module_name, + level=1, + ) + ) elif isinstance(type_, GraphQLScalarType): if name not in self.custom_scalars: name = INPUT_SCALARS_MAP.get(name, ANY) @@ -209,11 +265,33 @@ def _parse_graphql_type_name( else: used_custom_scalar = name name = self.custom_scalars[name].type_name + self._used_custom_scalars.append(used_custom_scalar) else: raise ParsingError(f"Incorrect argument type {name}") return generate_annotation_name(name, nullable), used_custom_scalar + def _get_return_type_and_from(self, final_type): + if isinstance(final_type, GraphQLObjectType): + return_type_name = f"{final_type.name}Fields" + from_ = CUSTOM_FIELDS_FILE_PATH.stem + elif isinstance(final_type, GraphQLInterfaceType): + return_type_name = f"{final_type.name}Interface" + from_ = CUSTOM_FIELDS_FILE_PATH.stem + elif isinstance(final_type, GraphQLUnionType): + return_type_name = f"{final_type.name}Union" + from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem + else: + return_type_name = "GraphQLField" + from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem + return return_type_name, from_ + + def _add_custom_scalar_imports(self): + for custom_scalar_name in self._used_custom_scalars: + scalar_data = self.custom_scalars[custom_scalar_name] + for import_ in generate_scalar_imports(scalar_data): + self._add_import(import_) + @staticmethod def _capitalize_first_letter(s: str) -> str: return s[0].upper() + s[1:] diff --git a/ariadne_codegen/client_generators/dependencies/base_operation.py b/ariadne_codegen/client_generators/dependencies/base_operation.py index a488cc73..1396512c 100644 --- a/ariadne_codegen/client_generators/dependencies/base_operation.py +++ b/ariadne_codegen/client_generators/dependencies/base_operation.py @@ -2,91 +2,79 @@ from graphql import ( ArgumentNode, - BooleanValueNode, FieldNode, - FloatValueNode, InlineFragmentNode, - IntValueNode, NamedTypeNode, NameNode, - ObjectFieldNode, - ObjectValueNode, SelectionSetNode, - StringValueNode, + VariableNode, ) -from .base_model import BaseModel - class GraphQLArgument: - def __init__(self, argument_name: str, value: Any): + def __init__(self, argument_name: str): self._name = argument_name - self._value = self._convert_value(value) - - def _convert_value( - self, value: Any - ) -> Union[ - StringValueNode, IntValueNode, FloatValueNode, BooleanValueNode, ObjectValueNode - ]: - if isinstance(value, str): - return StringValueNode(value=value) - if isinstance(value, int): - return IntValueNode(value=str(value)) - if isinstance(value, float): - return FloatValueNode(value=str(value)) - if isinstance(value, bool): - return BooleanValueNode(value=value) - if isinstance(value, BaseModel): - fields = [ - ObjectFieldNode(name=NameNode(value=k), value=self._convert_value(v)) - for k, v in value.model_dump().items() - ] - return ObjectValueNode(fields=fields) - raise TypeError(f"Unsupported argument type: {type(value)}") + self._variable_name = argument_name - def to_ast(self) -> ArgumentNode: - return ArgumentNode(name=NameNode(value=self._name), value=self._value) + def to_ast(self, idx: int) -> ArgumentNode: + return ArgumentNode( + name=NameNode(value=self._name), + value=VariableNode(name=NameNode(value=f"{idx}_{self._variable_name}")), + ) class GraphQLField: - def __init__(self, field_name: str, **kwargs: Any) -> None: - self._field_name: str = field_name - self._arguments: List[GraphQLArgument] = [ - GraphQLArgument(k, v) for k, v in kwargs.items() if v - ] - self._subfields: List["GraphQLField"] = [] + def __init__( + self, field_name: str, arguments: Optional[Dict[str, Any]] = None + ) -> None: + self._field_name = field_name + self._variables = arguments or {} + self._arguments = [GraphQLArgument(k) for k in self._variables] + self._subfields: List[GraphQLField] = [] self._alias: Optional[str] = None - self._inline_fragments: Dict[str, Tuple["GraphQLField", ...]] = {} + self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} + + def get_variables_types(self, idx: int) -> Dict[str, Any]: + return {f"{idx}_{k}": v["type"] for k, v in self._variables.items()} + + def get_processed_variables(self, idx: int) -> Dict[str, Any]: + return {f"{idx}_{k}": v["value"] for k, v in self._variables.items()} def alias(self, alias: str) -> "GraphQLField": self._alias = alias return self + def add_subfield(self, subfield: "GraphQLField") -> None: + self._subfields.append(subfield) + + def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> None: + self._inline_fragments[type_name] = subfields + def _build_field_name(self) -> str: - if self._alias: - return f"{self._alias}: {self._field_name}" - return self._field_name + return f"{self._alias}: {self._field_name}" if self._alias else self._field_name - def to_ast(self) -> FieldNode: + def _build_selections(self, idx: int) -> List[Union[FieldNode, InlineFragmentNode]]: selections: List[Union[FieldNode, InlineFragmentNode]] = [ - sub_field.to_ast() for sub_field in self._subfields + subfield.to_ast(idx) for subfield in self._subfields ] - if self._inline_fragments: - selections.extend( - [ - InlineFragmentNode( - type_condition=NamedTypeNode(name=NameNode(value=name)), - selection_set=SelectionSetNode( - selections=[sub_field.to_ast() for sub_field in subfields] - ), - ) - for name, subfields in self._inline_fragments.items() - ] + for name, subfields in self._inline_fragments.items(): + selections.append( + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[subfield.to_ast(idx) for subfield in subfields] + ), + ) ) + return selections + + def to_ast(self, idx: int) -> FieldNode: return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=[arg.to_ast() for arg in self._arguments], + arguments=[arg.to_ast(idx) for arg in self._arguments], selection_set=( - SelectionSetNode(selections=selections) if selections else None + SelectionSetNode(selections=self._build_selections(idx)) + if self._subfields or self._inline_fragments + else None ), ) diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index a49901c3..b3adb052 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -460,7 +460,9 @@ def get_package_generator( custom_scalars=settings.scalars, plugin_manager=plugin_manager, ) - custom_fields_generator = CustomFieldsGenerator(schema=schema) + custom_fields_generator = CustomFieldsGenerator( + schema=schema, custom_scalars=settings.scalars, plugin_manager=plugin_manager + ) custom_fields_typing_generator = CustomFieldsTypingGenerator(schema=schema) custom_query_generator = None if schema.query_type: @@ -471,6 +473,12 @@ def get_package_generator( enums_module_name=settings.enums_module_name, custom_scalars=settings.scalars, plugin_manager=plugin_manager, + arguments_generator=ArgumentsGenerator( + schema=schema, + convert_to_snake_case=settings.convert_to_snake_case, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ), ) custom_mutation_generator = None if schema.mutation_type: @@ -481,6 +489,12 @@ def get_package_generator( enums_module_name=settings.enums_module_name, custom_scalars=settings.scalars, plugin_manager=plugin_manager, + arguments_generator=ArgumentsGenerator( + schema=schema, + convert_to_snake_case=settings.convert_to_snake_case, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ), ) return PackageGenerator( diff --git a/pyproject.toml b/pyproject.toml index 81ab64dc..70a91ae4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "ariadne-codegen" description = "Generate fully typed GraphQL client from schema, queries and mutations!" authors = [{ name = "Mirumee Software", email = "hello@mirumee.com" }] -version = "0.14.0.dev1" +version = "0.14.0.dev2" readme = "README.md" license = { file = "LICENSE" } classifiers = [ diff --git a/tests/main/clients/custom_query_builder/expected_client/base_operation.py b/tests/main/clients/custom_query_builder/expected_client/base_operation.py index a488cc73..1396512c 100644 --- a/tests/main/clients/custom_query_builder/expected_client/base_operation.py +++ b/tests/main/clients/custom_query_builder/expected_client/base_operation.py @@ -2,91 +2,79 @@ from graphql import ( ArgumentNode, - BooleanValueNode, FieldNode, - FloatValueNode, InlineFragmentNode, - IntValueNode, NamedTypeNode, NameNode, - ObjectFieldNode, - ObjectValueNode, SelectionSetNode, - StringValueNode, + VariableNode, ) -from .base_model import BaseModel - class GraphQLArgument: - def __init__(self, argument_name: str, value: Any): + def __init__(self, argument_name: str): self._name = argument_name - self._value = self._convert_value(value) - - def _convert_value( - self, value: Any - ) -> Union[ - StringValueNode, IntValueNode, FloatValueNode, BooleanValueNode, ObjectValueNode - ]: - if isinstance(value, str): - return StringValueNode(value=value) - if isinstance(value, int): - return IntValueNode(value=str(value)) - if isinstance(value, float): - return FloatValueNode(value=str(value)) - if isinstance(value, bool): - return BooleanValueNode(value=value) - if isinstance(value, BaseModel): - fields = [ - ObjectFieldNode(name=NameNode(value=k), value=self._convert_value(v)) - for k, v in value.model_dump().items() - ] - return ObjectValueNode(fields=fields) - raise TypeError(f"Unsupported argument type: {type(value)}") + self._variable_name = argument_name - def to_ast(self) -> ArgumentNode: - return ArgumentNode(name=NameNode(value=self._name), value=self._value) + def to_ast(self, idx: int) -> ArgumentNode: + return ArgumentNode( + name=NameNode(value=self._name), + value=VariableNode(name=NameNode(value=f"{idx}_{self._variable_name}")), + ) class GraphQLField: - def __init__(self, field_name: str, **kwargs: Any) -> None: - self._field_name: str = field_name - self._arguments: List[GraphQLArgument] = [ - GraphQLArgument(k, v) for k, v in kwargs.items() if v - ] - self._subfields: List["GraphQLField"] = [] + def __init__( + self, field_name: str, arguments: Optional[Dict[str, Any]] = None + ) -> None: + self._field_name = field_name + self._variables = arguments or {} + self._arguments = [GraphQLArgument(k) for k in self._variables] + self._subfields: List[GraphQLField] = [] self._alias: Optional[str] = None - self._inline_fragments: Dict[str, Tuple["GraphQLField", ...]] = {} + self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} + + def get_variables_types(self, idx: int) -> Dict[str, Any]: + return {f"{idx}_{k}": v["type"] for k, v in self._variables.items()} + + def get_processed_variables(self, idx: int) -> Dict[str, Any]: + return {f"{idx}_{k}": v["value"] for k, v in self._variables.items()} def alias(self, alias: str) -> "GraphQLField": self._alias = alias return self + def add_subfield(self, subfield: "GraphQLField") -> None: + self._subfields.append(subfield) + + def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> None: + self._inline_fragments[type_name] = subfields + def _build_field_name(self) -> str: - if self._alias: - return f"{self._alias}: {self._field_name}" - return self._field_name + return f"{self._alias}: {self._field_name}" if self._alias else self._field_name - def to_ast(self) -> FieldNode: + def _build_selections(self, idx: int) -> List[Union[FieldNode, InlineFragmentNode]]: selections: List[Union[FieldNode, InlineFragmentNode]] = [ - sub_field.to_ast() for sub_field in self._subfields + subfield.to_ast(idx) for subfield in self._subfields ] - if self._inline_fragments: - selections.extend( - [ - InlineFragmentNode( - type_condition=NamedTypeNode(name=NameNode(value=name)), - selection_set=SelectionSetNode( - selections=[sub_field.to_ast() for sub_field in subfields] - ), - ) - for name, subfields in self._inline_fragments.items() - ] + for name, subfields in self._inline_fragments.items(): + selections.append( + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[subfield.to_ast(idx) for subfield in subfields] + ), + ) ) + return selections + + def to_ast(self, idx: int) -> FieldNode: return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=[arg.to_ast() for arg in self._arguments], + arguments=[arg.to_ast(idx) for arg in self._arguments], selection_set=( - SelectionSetNode(selections=selections) if selections else None + SelectionSetNode(selections=self._build_selections(idx)) + if self._subfields or self._inline_fragments + else None ), ) diff --git a/tests/main/clients/custom_query_builder/expected_client/client.py b/tests/main/clients/custom_query_builder/expected_client/client.py index 4e47fc3b..53148b23 100644 --- a/tests/main/clients/custom_query_builder/expected_client/client.py +++ b/tests/main/clients/custom_query_builder/expected_client/client.py @@ -1,11 +1,14 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Tuple from graphql import ( DocumentNode, + NamedTypeNode, NameNode, OperationDefinitionNode, OperationType, SelectionSetNode, + VariableDefinitionNode, + VariableNode, print_ast, ) @@ -21,21 +24,64 @@ class Client(AsyncBaseClient): async def execute_custom_operation( self, *fields: GraphQLField, operation_type: OperationType, operation_name: str ) -> Dict[str, Any]: - operation_ast = DocumentNode( + variables_types_combined, processed_variables_combined = ( + self._combine_variables(fields) + ) + variable_definitions = self._build_variable_definitions( + variables_types_combined + ) + operation_ast = self._build_operation_ast( + fields, operation_type, operation_name, variable_definitions + ) + response = await self.execute( + print_ast(operation_ast), + variables=processed_variables_combined, + operation_name=operation_name, + ) + return self.get_data(response) + + def _combine_variables( + self, fields: Tuple[GraphQLField, ...] + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + variables_types_combined = {} + processed_variables_combined = {} + for idx, field in enumerate(fields): + variables_types_combined.update(field.get_variables_types(idx)) + processed_variables_combined.update(field.get_processed_variables(idx)) + return (variables_types_combined, processed_variables_combined) + + def _build_variable_definitions( + self, variables_types_combined: Dict[str, str] + ) -> List[VariableDefinitionNode]: + return [ + VariableDefinitionNode( + variable=VariableNode(name=NameNode(value=var_name)), + type=NamedTypeNode(name=NameNode(value=var_value)), + ) + for var_name, var_value in variables_types_combined.items() + ] + + def _build_operation_ast( + self, + fields: Tuple[GraphQLField, ...], + operation_type: OperationType, + operation_name: str, + variable_definitions: List[VariableDefinitionNode], + ) -> DocumentNode: + return DocumentNode( definitions=[ OperationDefinitionNode( operation=operation_type, name=NameNode(value=operation_name), + variable_definitions=variable_definitions, selection_set=SelectionSetNode( - selections=[field.to_ast() for field in fields] + selections=[ + field.to_ast(idx) for idx, field in enumerate(fields) + ] ), ) ] ) - response = await self.execute( - print_ast(operation_ast), operation_name=operation_name - ) - return self.get_data(response) async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: return await self.execute_custom_operation( diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py index 7bb33e07..310c5d8b 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Any, Optional, Union from . import ( AppGraphQLField, @@ -77,23 +77,23 @@ def fields(self, *subfields: MetadataItemGraphQLField) -> "MetadataItemFields": class ObjectWithMetadataInterface(GraphQLField): @classmethod def private_metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("private_metadata") + return MetadataItemFields("private_metadata", arguments={}) @classmethod - def private_metafield( - cls, *, key: Optional[str] = None - ) -> "ObjectWithMetadataGraphQLField": - return ObjectWithMetadataGraphQLField("private_metafield", key=key) + def private_metafield(cls, key: str) -> "ObjectWithMetadataGraphQLField": + return ObjectWithMetadataGraphQLField( + "private_metafield", arguments={"key": {"type": "String!", "value": key}} + ) @classmethod def metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("metadata") + return MetadataItemFields("metadata", arguments={}) @classmethod - def metafield( - cls, *, key: Optional[str] = None - ) -> "ObjectWithMetadataGraphQLField": - return ObjectWithMetadataGraphQLField("metafield", key=key) + def metafield(cls, key: str) -> "ObjectWithMetadataGraphQLField": + return ObjectWithMetadataGraphQLField( + "metafield", arguments={"key": {"type": "String!", "value": key}} + ) def fields( self, *subfields: Union[ObjectWithMetadataGraphQLField, "MetadataItemFields"] @@ -126,19 +126,23 @@ class ProductFields(GraphQLField): @classmethod def private_metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("private_metadata") + return MetadataItemFields("private_metadata", arguments={}) @classmethod - def private_metafield(cls, *, key: Optional[str] = None) -> "ProductGraphQLField": - return ProductGraphQLField("private_metafield", key=key) + def private_metafield(cls, key: str) -> "ProductGraphQLField": + return ProductGraphQLField( + "private_metafield", arguments={"key": {"type": "String!", "value": key}} + ) @classmethod def metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("metadata") + return MetadataItemFields("metadata", arguments={}) @classmethod - def metafield(cls, *, key: Optional[str] = None) -> "ProductGraphQLField": - return ProductGraphQLField("metafield", key=key) + def metafield(cls, key: str) -> "ProductGraphQLField": + return ProductGraphQLField( + "metafield", arguments={"key": {"type": "String!", "value": key}} + ) def fields( self, *subfields: Union[ProductGraphQLField, "MetadataItemFields"] @@ -150,11 +154,11 @@ def fields( class ProductCountableConnectionFields(GraphQLField): @classmethod def edges(cls) -> "ProductCountableEdgeFields": - return ProductCountableEdgeFields("edges") + return ProductCountableEdgeFields("edges", arguments={}) @classmethod def page_info(cls) -> "PageInfoFields": - return PageInfoFields("page_info") + return PageInfoFields("page_info", arguments={}) total_count: ProductCountableConnectionGraphQLField = ( ProductCountableConnectionGraphQLField("totalCount") @@ -175,7 +179,7 @@ def fields( class ProductCountableEdgeFields(GraphQLField): @classmethod def node(cls) -> "ProductFields": - return ProductFields("node") + return ProductFields("node", arguments={}) cursor: ProductCountableEdgeGraphQLField = ProductCountableEdgeGraphQLField( "cursor" @@ -218,7 +222,7 @@ def fields( class ProductTypeCountableConnectionFields(GraphQLField): @classmethod def page_info(cls) -> "PageInfoFields": - return PageInfoFields("page_info") + return PageInfoFields("page_info", arguments={}) def fields( self, @@ -231,11 +235,11 @@ def fields( class TranslatableItemConnectionFields(GraphQLField): @classmethod def page_info(cls) -> "PageInfoFields": - return PageInfoFields("page_info") + return PageInfoFields("page_info", arguments={}) @classmethod def edges(cls) -> "TranslatableItemEdgeFields": - return TranslatableItemEdgeFields("edges") + return TranslatableItemEdgeFields("edges", arguments={}) total_count: TranslatableItemConnectionGraphQLField = ( TranslatableItemConnectionGraphQLField("totalCount") @@ -270,15 +274,15 @@ def fields( class UpdateMetadataFields(GraphQLField): @classmethod def metadata_errors(cls) -> "MetadataErrorFields": - return MetadataErrorFields("metadata_errors") + return MetadataErrorFields("metadata_errors", arguments={}) @classmethod def errors(cls) -> "MetadataErrorFields": - return MetadataErrorFields("errors") + return MetadataErrorFields("errors", arguments={}) @classmethod def item(cls) -> "ObjectWithMetadataInterface": - return ObjectWithMetadataInterface("item") + return ObjectWithMetadataInterface("item", arguments={}) def fields( self, diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py index f21dfa8b..a4937a74 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py @@ -5,5 +5,7 @@ class Mutation: @classmethod - def update_metadata(cls, *, id: Optional[str] = None) -> UpdateMetadataFields: - return UpdateMetadataFields(field_name="updateMetadata", id=id) + def update_metadata(cls, id: str) -> UpdateMetadataFields: + return UpdateMetadataFields( + field_name="updateMetadata", arguments={"id": {"type": "ID!", "value": id}} + ) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py index d45f66d1..f28580fb 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py @@ -14,16 +14,22 @@ def products( cls, *, channel: Optional[str] = None, first: Optional[int] = None ) -> ProductCountableConnectionFields: return ProductCountableConnectionFields( - field_name="products", channel=channel, first=first + field_name="products", + arguments={ + "channel": {"type": "String", "value": channel}, + "first": {"type": "Int", "value": first}, + }, ) @classmethod def app(cls) -> AppFields: - return AppFields(field_name="app") + return AppFields(field_name="app", arguments={}) @classmethod def product_types(cls) -> ProductTypeCountableConnectionFields: - return ProductTypeCountableConnectionFields(field_name="productTypes") + return ProductTypeCountableConnectionFields( + field_name="productTypes", arguments={} + ) @classmethod def translations( @@ -36,8 +42,10 @@ def translations( ) -> TranslatableItemConnectionFields: return TranslatableItemConnectionFields( field_name="translations", - before=before, - after=after, - first=first, - last=last, + arguments={ + "before": {"type": "String", "value": before}, + "after": {"type": "String", "value": after}, + "first": {"type": "Int", "value": first}, + "last": {"type": "Int", "value": last}, + }, ) From 0d6f38303af963c1a9fcac328b9660f6fa1b3962 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Wed, 10 Jul 2024 12:23:00 +0200 Subject: [PATCH 04/11] refactor custom operation file --- .../client_generators/custom_operation.py | 189 ++++++++++-------- 1 file changed, 104 insertions(+), 85 deletions(-) diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py index d0545ad6..f23e13fc 100644 --- a/ariadne_codegen/client_generators/custom_operation.py +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -114,119 +114,131 @@ def _generate_method( operation_args, final_type, ) -> ast.FunctionDef: - arguments = self._generate_method_arguments(operation_args) - return_type_name, from_ = self._get_return_type_and_from(final_type) - - self._type_imports.append( - generate_import_from( - from_=from_, - names=[return_type_name], - level=1, - ) - ) + method_arguments, return_arguments = self._generate_arguments(operation_args) + return_type_name = self._get_return_type_and_from(final_type) return generate_method_definition( name=str_to_snake_case(operation_name), - arguments=arguments, + arguments=method_arguments, return_type=generate_name(return_type_name), body=[ - self._generate_return_stmt( - return_type_name, - operation_name, - operation_args, + generate_return( + value=generate_call( + func=generate_name(return_type_name), + args=[], + keywords=[ + generate_keyword( + arg="field_name", + value=generate_constant(value=operation_name), + ), + return_arguments, + ], + ) ) ], decorator_list=[generate_name("classmethod")], ) - def _generate_method_arguments(self, operation_args): + def _generate_arguments(self, operation_args): cls_arg = generate_arg(name="cls") - kw_only_args, kw_defaults, args = self._generate_kw_args_and_defaults( - operation_args, - ) - return generate_arguments( - args=[cls_arg, *args], - kwonlyargs=kw_only_args, - kw_defaults=kw_defaults, - ) + args, kw_only_args, kw_defaults = [], [], [] + return_arguments_keys, return_arguments_values = [], [] - def _generate_kw_args_and_defaults(self, operation_args): - kw_only_args = [] - kw_defaults = [] - args = [] - for arg_name, arg_type in operation_args.items(): - arg_name = process_name( - arg_name, - convert_to_snake_case=self.convert_to_snake_case, - ) - arg_final_type = get_final_type(arg_type) - is_required = isinstance(arg_type.type, GraphQLNonNull) - annotation, _ = self._parse_graphql_type_name( - arg_final_type, - not is_required, + for arg_name, arg_value in operation_args.items(): + final_type = get_final_type(arg_value) + is_required = isinstance(arg_value.type, GraphQLNonNull) + name = self._process_argument_name(arg_name) + annotation, used_custom_scalar = self._parse_graphql_type_name( + final_type, not is_required ) - arg = generate_arg(name=arg_name, annotation=annotation) - if is_required: - args.append(arg) - else: - kw_only_args.append(arg) - kw_defaults.append(generate_constant(value=None)) - return kw_only_args, kw_defaults, args - def _generate_return_stmt(self, return_type_name, operation_name, operation_args): - arguments_dict = self._generate_arguments_dict(operation_args) + self._accumulate_method_arguments( + args, kw_only_args, kw_defaults, arg_name, annotation, is_required + ) + self._accumulate_return_arguments( + return_arguments_keys, + return_arguments_values, + arg_name, + name, + final_type, + is_required, + used_custom_scalar, + ) - arguments_keyword = generate_keyword( - arg="arguments", - value=generate_dict( - keys=arguments_dict.keys(), - values=arguments_dict.values(), - ), + method_arguments = self._assemble_method_arguments( + cls_arg, args, kw_only_args, kw_defaults + ) + return_arguments = self._assemble_return_arguments( + return_arguments_keys, return_arguments_values ) - return generate_return( - value=generate_call( - func=generate_name(return_type_name), - args=[], - keywords=[ - generate_keyword( - arg="field_name", value=generate_constant(value=operation_name) - ), - arguments_keyword, - ], - ) + return method_arguments, return_arguments + + def _process_argument_name(self, arg_name): + return process_name(arg_name, convert_to_snake_case=self.convert_to_snake_case) + + def _accumulate_method_arguments( + self, args, kw_only_args, kw_defaults, arg_name, annotation, is_required + ): + if is_required: + args.append(generate_arg(name=arg_name, annotation=annotation)) + else: + kw_only_args.append(generate_arg(name=arg_name, annotation=annotation)) + kw_defaults.append(generate_constant(value=None)) + + def _accumulate_return_arguments( + self, + return_arguments_keys, + return_arguments_values, + arg_name, + name, + final_type, + is_required, + used_custom_scalar, + ): + constant_value = f"{final_type.name}!" if is_required else final_type.name + return_arg_dict_value = self._generate_return_arg_value( + name, + used_custom_scalar, ) - def _generate_arguments_dict(self, operation_args): - arguments_dict = {} - for arg_name, arg_value in operation_args.items(): - final_type = get_final_type(arg_value) - is_required = isinstance(arg_value.type, GraphQLNonNull) - constant_value = f"{final_type.name}!" if is_required else final_type.name - arguments_dict[generate_constant(arg_name)] = generate_dict( + return_arguments_keys.append(generate_constant(arg_name)) + return_arguments_values.append( + generate_dict( keys=[generate_constant("type"), generate_constant("value")], - values=[ - generate_constant(constant_value), - self._get_dict_value(arg_name, arg_value), - ], + values=[generate_constant(constant_value), return_arg_dict_value], ) - return arguments_dict - - def _get_dict_value(self, name: str, arg_value) -> Union[ast.Name, ast.Call]: - name = process_name( - name, - convert_to_snake_case=self.convert_to_snake_case, ) - _, used_custom_scalar = self._parse_graphql_type_name(get_final_type(arg_value)) + + def _generate_return_arg_value(self, name, used_custom_scalar): + return_arg_dict_value = generate_name(name) + if used_custom_scalar: self._used_custom_scalars.append(used_custom_scalar) scalar_data = self.custom_scalars[used_custom_scalar] if scalar_data.serialize_name: - return generate_call( + return_arg_dict_value = generate_call( func=generate_name(scalar_data.serialize_name), args=[generate_name(name)], ) - return generate_name(name) + + return return_arg_dict_value + + def _assemble_method_arguments(self, cls_arg, args, kw_only_args, kw_defaults): + return generate_arguments( + args=[cls_arg, *args], + kwonlyargs=kw_only_args, + kw_defaults=kw_defaults, + ) + + def _assemble_return_arguments(self, keys, values): + return generate_keyword( + arg="arguments", + value=generate_dict( + keys=keys, + values=values, + ), + ) def _parse_graphql_type_name( self, type_, nullable: bool = True @@ -284,7 +296,14 @@ def _get_return_type_and_from(self, final_type): else: return_type_name = "GraphQLField" from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem - return return_type_name, from_ + self._type_imports.append( + generate_import_from( + from_=from_, + names=[return_type_name], + level=1, + ) + ) + return return_type_name def _add_custom_scalar_imports(self): for custom_scalar_name in self._used_custom_scalars: From f387fb7d421bde5e3859657a5ea7db5aa6375262 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Thu, 11 Jul 2024 15:51:23 +0200 Subject: [PATCH 05/11] Add fix to argument duplication --- ariadne_codegen/client_generators/client.py | 48 +++--- .../client_generators/custom_operation.py | 16 +- .../dependencies/base_operation.py | 76 +++++++--- ariadne_codegen/client_generators/utils.py | 2 + .../expected_client/base_operation.py | 76 +++++++--- .../expected_client/client.py | 11 +- .../graphql_client/__init__.py | 4 +- .../graphql_client/base_operation.py | 142 +++++++++++------- .../graphql_client/client.py | 67 ++++++++- .../graphql_client/custom_fields.py | 39 +++-- .../graphql_client/custom_mutations.py | 67 +++++---- .../graphql_client/custom_queries.py | 59 +++----- .../graphql_client/custom_typing_fields.py | 12 +- .../graphql_client/enums.py | 3 +- .../graphql_client/input_types.py | 12 +- .../custom_operation_builder/schema.graphql | 98 ++++++++++++ .../test_operation_build.py | 139 +++++++++-------- 17 files changed, 569 insertions(+), 302 deletions(-) create mode 100644 tests/main/custom_operation_builder/schema.graphql diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 7bc42a91..2d8c610c 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -223,17 +223,15 @@ def create_combine_variables_method(self): value=generate_dict(), ), ast.For( - target=generate_tuple( - elts=[ - generate_name("idx"), - generate_name("field"), - ], - ), - iter=generate_call( - func=generate_name("enumerate"), - args=[generate_name("fields")], - ), + target=generate_name("field"), + iter=generate_name("fields"), body=[ + generate_assign( + targets=["formatted_variables"], + value=generate_call( + func=generate_name("field.get_formatted_variables") + ), + ), generate_expr( value=generate_call( func=generate_attribute( @@ -241,12 +239,15 @@ def create_combine_variables_method(self): attr="update", ), args=[ - generate_call( - func=generate_attribute( - value=generate_name("field"), - attr="get_variables_types", - ), - args=[generate_name("idx")], + ast.DictComp( + key=generate_name("k"), + value=generate_name('v["type"]'), + generators=[ + generate_comp( + target="k, v", + iter_="formatted_variables.items()", + ) + ], ) ], ) @@ -258,12 +259,15 @@ def create_combine_variables_method(self): attr="update", ), args=[ - generate_call( - func=generate_attribute( - value=generate_name("field"), - attr="get_processed_variables", - ), - args=[generate_name("idx")], + ast.DictComp( + key=generate_name("k"), + value=generate_name('v["value"]'), + generators=[ + generate_comp( + target="k, v", + iter_="formatted_variables.items()", + ) + ], ) ], ) diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py index f23e13fc..0fc9d102 100644 --- a/ariadne_codegen/client_generators/custom_operation.py +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -147,13 +147,16 @@ def _generate_arguments(self, operation_args): for arg_name, arg_value in operation_args.items(): final_type = get_final_type(arg_value) is_required = isinstance(arg_value.type, GraphQLNonNull) - name = self._process_argument_name(arg_name) + name = process_name( + arg_name, + convert_to_snake_case=self.convert_to_snake_case, + ) annotation, used_custom_scalar = self._parse_graphql_type_name( final_type, not is_required ) self._accumulate_method_arguments( - args, kw_only_args, kw_defaults, arg_name, annotation, is_required + args, kw_only_args, kw_defaults, name, annotation, is_required ) self._accumulate_return_arguments( return_arguments_keys, @@ -174,16 +177,13 @@ def _generate_arguments(self, operation_args): return method_arguments, return_arguments - def _process_argument_name(self, arg_name): - return process_name(arg_name, convert_to_snake_case=self.convert_to_snake_case) - def _accumulate_method_arguments( - self, args, kw_only_args, kw_defaults, arg_name, annotation, is_required + self, args, kw_only_args, kw_defaults, name, annotation, is_required ): if is_required: - args.append(generate_arg(name=arg_name, annotation=annotation)) + args.append(generate_arg(name=name, annotation=annotation)) else: - kw_only_args.append(generate_arg(name=arg_name, annotation=annotation)) + kw_only_args.append(generate_arg(name=name, annotation=annotation)) kw_defaults.append(generate_constant(value=None)) def _accumulate_return_arguments( diff --git a/ariadne_codegen/client_generators/dependencies/base_operation.py b/ariadne_codegen/client_generators/dependencies/base_operation.py index 1396512c..1082ca4d 100644 --- a/ariadne_codegen/client_generators/dependencies/base_operation.py +++ b/ariadne_codegen/client_generators/dependencies/base_operation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from graphql import ( ArgumentNode, @@ -12,34 +12,28 @@ class GraphQLArgument: - def __init__(self, argument_name: str): + def __init__(self, argument_name: str, argument_value: Any): self._name = argument_name - self._variable_name = argument_name + self._value = argument_value - def to_ast(self, idx: int) -> ArgumentNode: + def to_ast(self) -> ArgumentNode: return ArgumentNode( name=NameNode(value=self._name), - value=VariableNode(name=NameNode(value=f"{idx}_{self._variable_name}")), + value=VariableNode(name=NameNode(value=self._value)), ) class GraphQLField: def __init__( - self, field_name: str, arguments: Optional[Dict[str, Any]] = None + self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None ) -> None: self._field_name = field_name self._variables = arguments or {} - self._arguments = [GraphQLArgument(k) for k in self._variables] + self.formatted_variables: Dict[str, Dict[str, Any]] = {} self._subfields: List[GraphQLField] = [] self._alias: Optional[str] = None self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} - def get_variables_types(self, idx: int) -> Dict[str, Any]: - return {f"{idx}_{k}": v["type"] for k, v in self._variables.items()} - - def get_processed_variables(self, idx: int) -> Dict[str, Any]: - return {f"{idx}_{k}": v["value"] for k, v in self._variables.items()} - def alias(self, alias: str) -> "GraphQLField": self._alias = alias return self @@ -53,28 +47,72 @@ def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> Non def _build_field_name(self) -> str: return f"{self._alias}: {self._field_name}" if self._alias else self._field_name - def _build_selections(self, idx: int) -> List[Union[FieldNode, InlineFragmentNode]]: + def _build_selections( + self, idx: int, used_names: Set[str] + ) -> List[Union[FieldNode, InlineFragmentNode]]: selections: List[Union[FieldNode, InlineFragmentNode]] = [ - subfield.to_ast(idx) for subfield in self._subfields + subfield.to_ast(idx, used_names) for subfield in self._subfields ] for name, subfields in self._inline_fragments.items(): selections.append( InlineFragmentNode( type_condition=NamedTypeNode(name=NameNode(value=name)), selection_set=SelectionSetNode( - selections=[subfield.to_ast(idx) for subfield in subfields] + selections=[ + subfield.to_ast(idx, used_names) for subfield in subfields + ] ), ) ) return selections - def to_ast(self, idx: int) -> FieldNode: + def _format_variable_name( + self, idx: int, var_name: str, used_names: Set[str] + ) -> str: + base_name = f"{idx}_{var_name}" + unique_name = base_name + counter = 1 + while unique_name in used_names: + unique_name = f"{base_name}_{counter}" + counter += 1 + used_names.add(unique_name) + return unique_name + + def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + self.formatted_variables = {} + for k, v in self._variables.items(): + unique_name = self._format_variable_name(idx, k, used_names) + self.formatted_variables[unique_name] = { + "name": k, + "type": v["type"], + "value": v["value"], + } + + def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + if used_names is None: + used_names = set() + self._collect_all_variables(idx, used_names) + formatted_args = [ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ] return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=[arg.to_ast(idx) for arg in self._arguments], + arguments=formatted_args, selection_set=( - SelectionSetNode(selections=self._build_selections(idx)) + SelectionSetNode(selections=self._build_selections(idx, used_names)) if self._subfields or self._inline_fragments else None ), ) + + def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: + formatted_variables = self.formatted_variables + for subfield in self._subfields: + subfield.get_formatted_variables() + self.formatted_variables.update(subfield.formatted_variables) + for subfields in self._inline_fragments.values(): + for subfield in subfields: + subfield.get_formatted_variables() + self.formatted_variables.update(subfield.formatted_variables) + return formatted_variables diff --git a/ariadne_codegen/client_generators/utils.py b/ariadne_codegen/client_generators/utils.py index 3540677f..81b59402 100644 --- a/ariadne_codegen/client_generators/utils.py +++ b/ariadne_codegen/client_generators/utils.py @@ -46,6 +46,8 @@ def _collect_dependent_types(self, graphql_type: GraphQLObjectType) -> None: stack.extend(subfield_type.types) for interface in current_type.interfaces: stack.append(interface) + elif isinstance(current_type, GraphQLUnionType): + stack.extend(current_type.types) def get_final_type(type_): diff --git a/tests/main/clients/custom_query_builder/expected_client/base_operation.py b/tests/main/clients/custom_query_builder/expected_client/base_operation.py index 1396512c..9538b34b 100644 --- a/tests/main/clients/custom_query_builder/expected_client/base_operation.py +++ b/tests/main/clients/custom_query_builder/expected_client/base_operation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from graphql import ( ArgumentNode, @@ -12,34 +12,28 @@ class GraphQLArgument: - def __init__(self, argument_name: str): + def __init__(self, argument_name: str, argument_value: Any): self._name = argument_name - self._variable_name = argument_name + self._value = argument_value - def to_ast(self, idx: int) -> ArgumentNode: + def to_ast(self) -> ArgumentNode: return ArgumentNode( name=NameNode(value=self._name), - value=VariableNode(name=NameNode(value=f"{idx}_{self._variable_name}")), + value=VariableNode(name=NameNode(value=self._value)), ) class GraphQLField: def __init__( - self, field_name: str, arguments: Optional[Dict[str, Any]] = None + self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None ) -> None: self._field_name = field_name self._variables = arguments or {} - self._arguments = [GraphQLArgument(k) for k in self._variables] + self._formatted_variables: Dict[str, Dict[str, Any]] = {} self._subfields: List[GraphQLField] = [] self._alias: Optional[str] = None self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} - def get_variables_types(self, idx: int) -> Dict[str, Any]: - return {f"{idx}_{k}": v["type"] for k, v in self._variables.items()} - - def get_processed_variables(self, idx: int) -> Dict[str, Any]: - return {f"{idx}_{k}": v["value"] for k, v in self._variables.items()} - def alias(self, alias: str) -> "GraphQLField": self._alias = alias return self @@ -53,28 +47,72 @@ def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> Non def _build_field_name(self) -> str: return f"{self._alias}: {self._field_name}" if self._alias else self._field_name - def _build_selections(self, idx: int) -> List[Union[FieldNode, InlineFragmentNode]]: + def _build_selections( + self, idx: int, used_names: Set[str] + ) -> List[Union[FieldNode, InlineFragmentNode]]: selections: List[Union[FieldNode, InlineFragmentNode]] = [ - subfield.to_ast(idx) for subfield in self._subfields + subfield.to_ast(idx, used_names) for subfield in self._subfields ] for name, subfields in self._inline_fragments.items(): selections.append( InlineFragmentNode( type_condition=NamedTypeNode(name=NameNode(value=name)), selection_set=SelectionSetNode( - selections=[subfield.to_ast(idx) for subfield in subfields] + selections=[ + subfield.to_ast(idx, used_names) for subfield in subfields + ] ), ) ) return selections - def to_ast(self, idx: int) -> FieldNode: + def _format_variable_name( + self, idx: int, var_name: str, used_names: Set[str] + ) -> str: + base_name = f"{idx}_{var_name}" + unique_name = base_name + counter = 1 + while unique_name in used_names: + unique_name = f"{base_name}_{counter}" + counter += 1 + used_names.add(unique_name) + return unique_name + + def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + self._formatted_variables = {} + for k, v in self._variables.items(): + unique_name = self._format_variable_name(idx, k, used_names) + self._formatted_variables[unique_name] = { + "name": k, + "type": v["type"], + "value": v["value"], + } + + def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + if used_names is None: + used_names = set() + self._collect_all_variables(idx, used_names) + formatted_args = [ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self._formatted_variables.items() + ] return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=[arg.to_ast(idx) for arg in self._arguments], + arguments=formatted_args, selection_set=( - SelectionSetNode(selections=self._build_selections(idx)) + SelectionSetNode(selections=self._build_selections(idx, used_names)) if self._subfields or self._inline_fragments else None ), ) + + def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: + formatted_variables = self._formatted_variables + for subfield in self._subfields: + subfield.get_formatted_variables() + self._formatted_variables.update(subfield._formatted_variables) + for subfields in self._inline_fragments.values(): + for subfield in subfields: + subfield.get_formatted_variables() + self._formatted_variables.update(subfield._formatted_variables) + return formatted_variables diff --git a/tests/main/clients/custom_query_builder/expected_client/client.py b/tests/main/clients/custom_query_builder/expected_client/client.py index 53148b23..6ba1aabf 100644 --- a/tests/main/clients/custom_query_builder/expected_client/client.py +++ b/tests/main/clients/custom_query_builder/expected_client/client.py @@ -45,9 +45,14 @@ def _combine_variables( ) -> Tuple[Dict[str, Any], Dict[str, Any]]: variables_types_combined = {} processed_variables_combined = {} - for idx, field in enumerate(fields): - variables_types_combined.update(field.get_variables_types(idx)) - processed_variables_combined.update(field.get_processed_variables(idx)) + for field in fields: + formatted_variables = field.get_formatted_variables() + variables_types_combined.update( + {k: v["type"] for k, v in formatted_variables.items()} + ) + processed_variables_combined.update( + {k: v["value"] for k, v in formatted_variables.items()} + ) return (variables_types_combined, processed_variables_combined) def _build_variable_definitions( diff --git a/tests/main/custom_operation_builder/graphql_client/__init__.py b/tests/main/custom_operation_builder/graphql_client/__init__.py index 9251b3d8..ac2385e2 100644 --- a/tests/main/custom_operation_builder/graphql_client/__init__.py +++ b/tests/main/custom_operation_builder/graphql_client/__init__.py @@ -1,6 +1,6 @@ from .async_base_client import AsyncBaseClient from .base_model import BaseModel, Upload -from .client import AutoGenClient +from .client import Client from .custom_typing_fields import ( AdminGraphQLField, GuestGraphQLField, @@ -23,8 +23,8 @@ "AddUserInput", "AdminGraphQLField", "AsyncBaseClient", - "AutoGenClient", "BaseModel", + "Client", "GraphQLClientError", "GraphQLClientGraphQLError", "GraphQLClientGraphQLMultiError", diff --git a/tests/main/custom_operation_builder/graphql_client/base_operation.py b/tests/main/custom_operation_builder/graphql_client/base_operation.py index a488cc73..1082ca4d 100644 --- a/tests/main/custom_operation_builder/graphql_client/base_operation.py +++ b/tests/main/custom_operation_builder/graphql_client/base_operation.py @@ -1,92 +1,118 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from graphql import ( ArgumentNode, - BooleanValueNode, FieldNode, - FloatValueNode, InlineFragmentNode, - IntValueNode, NamedTypeNode, NameNode, - ObjectFieldNode, - ObjectValueNode, SelectionSetNode, - StringValueNode, + VariableNode, ) -from .base_model import BaseModel - class GraphQLArgument: - def __init__(self, argument_name: str, value: Any): + def __init__(self, argument_name: str, argument_value: Any): self._name = argument_name - self._value = self._convert_value(value) - - def _convert_value( - self, value: Any - ) -> Union[ - StringValueNode, IntValueNode, FloatValueNode, BooleanValueNode, ObjectValueNode - ]: - if isinstance(value, str): - return StringValueNode(value=value) - if isinstance(value, int): - return IntValueNode(value=str(value)) - if isinstance(value, float): - return FloatValueNode(value=str(value)) - if isinstance(value, bool): - return BooleanValueNode(value=value) - if isinstance(value, BaseModel): - fields = [ - ObjectFieldNode(name=NameNode(value=k), value=self._convert_value(v)) - for k, v in value.model_dump().items() - ] - return ObjectValueNode(fields=fields) - raise TypeError(f"Unsupported argument type: {type(value)}") + self._value = argument_value def to_ast(self) -> ArgumentNode: - return ArgumentNode(name=NameNode(value=self._name), value=self._value) + return ArgumentNode( + name=NameNode(value=self._name), + value=VariableNode(name=NameNode(value=self._value)), + ) class GraphQLField: - def __init__(self, field_name: str, **kwargs: Any) -> None: - self._field_name: str = field_name - self._arguments: List[GraphQLArgument] = [ - GraphQLArgument(k, v) for k, v in kwargs.items() if v - ] - self._subfields: List["GraphQLField"] = [] + def __init__( + self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None + ) -> None: + self._field_name = field_name + self._variables = arguments or {} + self.formatted_variables: Dict[str, Dict[str, Any]] = {} + self._subfields: List[GraphQLField] = [] self._alias: Optional[str] = None - self._inline_fragments: Dict[str, Tuple["GraphQLField", ...]] = {} + self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} def alias(self, alias: str) -> "GraphQLField": self._alias = alias return self + def add_subfield(self, subfield: "GraphQLField") -> None: + self._subfields.append(subfield) + + def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> None: + self._inline_fragments[type_name] = subfields + def _build_field_name(self) -> str: - if self._alias: - return f"{self._alias}: {self._field_name}" - return self._field_name + return f"{self._alias}: {self._field_name}" if self._alias else self._field_name - def to_ast(self) -> FieldNode: + def _build_selections( + self, idx: int, used_names: Set[str] + ) -> List[Union[FieldNode, InlineFragmentNode]]: selections: List[Union[FieldNode, InlineFragmentNode]] = [ - sub_field.to_ast() for sub_field in self._subfields + subfield.to_ast(idx, used_names) for subfield in self._subfields ] - if self._inline_fragments: - selections.extend( - [ - InlineFragmentNode( - type_condition=NamedTypeNode(name=NameNode(value=name)), - selection_set=SelectionSetNode( - selections=[sub_field.to_ast() for sub_field in subfields] - ), - ) - for name, subfields in self._inline_fragments.items() - ] + for name, subfields in self._inline_fragments.items(): + selections.append( + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[ + subfield.to_ast(idx, used_names) for subfield in subfields + ] + ), + ) ) + return selections + + def _format_variable_name( + self, idx: int, var_name: str, used_names: Set[str] + ) -> str: + base_name = f"{idx}_{var_name}" + unique_name = base_name + counter = 1 + while unique_name in used_names: + unique_name = f"{base_name}_{counter}" + counter += 1 + used_names.add(unique_name) + return unique_name + + def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + self.formatted_variables = {} + for k, v in self._variables.items(): + unique_name = self._format_variable_name(idx, k, used_names) + self.formatted_variables[unique_name] = { + "name": k, + "type": v["type"], + "value": v["value"], + } + + def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + if used_names is None: + used_names = set() + self._collect_all_variables(idx, used_names) + formatted_args = [ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ] return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=[arg.to_ast() for arg in self._arguments], + arguments=formatted_args, selection_set=( - SelectionSetNode(selections=selections) if selections else None + SelectionSetNode(selections=self._build_selections(idx, used_names)) + if self._subfields or self._inline_fragments + else None ), ) + + def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: + formatted_variables = self.formatted_variables + for subfield in self._subfields: + subfield.get_formatted_variables() + self.formatted_variables.update(subfield.formatted_variables) + for subfields in self._inline_fragments.values(): + for subfield in subfields: + subfield.get_formatted_variables() + self.formatted_variables.update(subfield.formatted_variables) + return formatted_variables diff --git a/tests/main/custom_operation_builder/graphql_client/client.py b/tests/main/custom_operation_builder/graphql_client/client.py index 0dc1ac7c..6ba1aabf 100644 --- a/tests/main/custom_operation_builder/graphql_client/client.py +++ b/tests/main/custom_operation_builder/graphql_client/client.py @@ -1,11 +1,14 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Tuple from graphql import ( DocumentNode, + NamedTypeNode, NameNode, OperationDefinitionNode, OperationType, SelectionSetNode, + VariableDefinitionNode, + VariableNode, print_ast, ) @@ -17,25 +20,73 @@ def gql(q: str) -> str: return q -class AutoGenClient(AsyncBaseClient): +class Client(AsyncBaseClient): async def execute_custom_operation( self, *fields: GraphQLField, operation_type: OperationType, operation_name: str ) -> Dict[str, Any]: - operation_ast = DocumentNode( + variables_types_combined, processed_variables_combined = ( + self._combine_variables(fields) + ) + variable_definitions = self._build_variable_definitions( + variables_types_combined + ) + operation_ast = self._build_operation_ast( + fields, operation_type, operation_name, variable_definitions + ) + response = await self.execute( + print_ast(operation_ast), + variables=processed_variables_combined, + operation_name=operation_name, + ) + return self.get_data(response) + + def _combine_variables( + self, fields: Tuple[GraphQLField, ...] + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + variables_types_combined = {} + processed_variables_combined = {} + for field in fields: + formatted_variables = field.get_formatted_variables() + variables_types_combined.update( + {k: v["type"] for k, v in formatted_variables.items()} + ) + processed_variables_combined.update( + {k: v["value"] for k, v in formatted_variables.items()} + ) + return (variables_types_combined, processed_variables_combined) + + def _build_variable_definitions( + self, variables_types_combined: Dict[str, str] + ) -> List[VariableDefinitionNode]: + return [ + VariableDefinitionNode( + variable=VariableNode(name=NameNode(value=var_name)), + type=NamedTypeNode(name=NameNode(value=var_value)), + ) + for var_name, var_value in variables_types_combined.items() + ] + + def _build_operation_ast( + self, + fields: Tuple[GraphQLField, ...], + operation_type: OperationType, + operation_name: str, + variable_definitions: List[VariableDefinitionNode], + ) -> DocumentNode: + return DocumentNode( definitions=[ OperationDefinitionNode( operation=operation_type, name=NameNode(value=operation_name), + variable_definitions=variable_definitions, selection_set=SelectionSetNode( - selections=[field.to_ast() for field in fields] + selections=[ + field.to_ast(idx) for idx, field in enumerate(fields) + ] ), ) ] ) - response = await self.execute( - print_ast(operation_ast), operation_name=operation_name - ) - return self.get_data(response) async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: return await self.execute_custom_operation( diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py index e0db87bc..e70a1c61 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from . import ( AdminGraphQLField, @@ -13,13 +13,15 @@ class AdminFields(GraphQLField): id: AdminGraphQLField = AdminGraphQLField("id") name: AdminGraphQLField = AdminGraphQLField("name") - email: AdminGraphQLField = AdminGraphQLField("email") privileges: AdminGraphQLField = AdminGraphQLField("privileges") + email: AdminGraphQLField = AdminGraphQLField("email") created_at: AdminGraphQLField = AdminGraphQLField("createdAt") @classmethod - def metafield(cls, *, key: Optional[str] = None) -> "AdminGraphQLField": - return AdminGraphQLField("metafield", key=key) + def metafield(cls, key: str) -> "AdminGraphQLField": + return AdminGraphQLField( + "metafield", arguments={"key": {"type": "String!", "value": key}} + ) def fields(self, *subfields: AdminGraphQLField) -> "AdminFields": self._subfields.extend(subfields) @@ -29,13 +31,15 @@ def fields(self, *subfields: AdminGraphQLField) -> "AdminFields": class GuestFields(GraphQLField): id: GuestGraphQLField = GuestGraphQLField("id") name: GuestGraphQLField = GuestGraphQLField("name") - email: GuestGraphQLField = GuestGraphQLField("email") visit_count: GuestGraphQLField = GuestGraphQLField("visitCount") + email: GuestGraphQLField = GuestGraphQLField("email") created_at: GuestGraphQLField = GuestGraphQLField("createdAt") @classmethod - def metafield(cls, *, key: Optional[str] = None) -> "GuestGraphQLField": - return GuestGraphQLField("metafield", key=key) + def metafield(cls, key: str) -> "GuestGraphQLField": + return GuestGraphQLField( + "metafield", arguments={"key": {"type": "String!", "value": key}} + ) def fields(self, *subfields: GuestGraphQLField) -> "GuestFields": self._subfields.extend(subfields) @@ -48,8 +52,10 @@ class PersonInterface(GraphQLField): email: PersonGraphQLField = PersonGraphQLField("email") @classmethod - def metafield(cls, *, key: Optional[str] = None) -> "PersonGraphQLField": - return PersonGraphQLField("metafield", key=key) + def metafield(cls, key: str) -> "PersonGraphQLField": + return PersonGraphQLField( + "metafield", arguments={"key": {"type": "String!", "value": key}} + ) def fields(self, *subfields: PersonGraphQLField) -> "PersonInterface": self._subfields.extend(subfields) @@ -67,7 +73,7 @@ class PostFields(GraphQLField): @classmethod def author(cls) -> "PersonInterface": - return PersonInterface("author") + return PersonInterface("author", arguments={}) published_at: PostGraphQLField = PostGraphQLField("publishedAt") @@ -81,19 +87,20 @@ def fields( class UserFields(GraphQLField): id: UserGraphQLField = UserGraphQLField("id") name: UserGraphQLField = UserGraphQLField("name") - email: UserGraphQLField = UserGraphQLField("email") age: UserGraphQLField = UserGraphQLField("age") + email: UserGraphQLField = UserGraphQLField("email") role: UserGraphQLField = UserGraphQLField("role") + created_at: UserGraphQLField = UserGraphQLField("createdAt") @classmethod def friends(cls) -> "UserFields": - return UserFields("friends") - - created_at: UserGraphQLField = UserGraphQLField("createdAt") + return UserFields("friends", arguments={}) @classmethod - def metafield(cls, *, key: Optional[str] = None) -> "UserGraphQLField": - return UserGraphQLField("metafield", key=key) + def metafield(cls, key: str) -> "UserGraphQLField": + return UserGraphQLField( + "metafield", arguments={"key": {"type": "String!", "value": key}} + ) def fields(self, *subfields: Union[UserGraphQLField, "UserFields"]) -> "UserFields": self._subfields.extend(subfields) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py index 7bf5894b..de4ba602 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Optional from .custom_fields import PostFields, UserFields from .input_types import AddUserInput, UpdateUserInput @@ -6,58 +6,65 @@ class Mutation: @classmethod - def add_user(cls, *, user_input: Optional[AddUserInput] = None) -> UserFields: - return UserFields(field_name="addUser", user_input=user_input) + def add_user(cls, user_input: AddUserInput) -> UserFields: + return UserFields( + field_name="addUser", + arguments={"user_input": {"type": "AddUserInput!", "value": user_input}}, + ) @classmethod - def update_user( - cls, - *, - user_id: Optional[str] = None, - user_input: Optional[UpdateUserInput] = None - ) -> UserFields: + def update_user(cls, user_id: str, user_input: UpdateUserInput) -> UserFields: return UserFields( - field_name="updateUser", user_id=user_id, user_input=user_input + field_name="updateUser", + arguments={ + "user_id": {"type": "ID!", "value": user_id}, + "user_input": {"type": "UpdateUserInput!", "value": user_input}, + }, ) @classmethod - def delete_user(cls, *, user_id: Optional[str] = None) -> UserFields: - return UserFields(field_name="deleteUser", user_id=user_id) + def delete_user(cls, user_id: str) -> UserFields: + return UserFields( + field_name="deleteUser", + arguments={"user_id": {"type": "ID!", "value": user_id}}, + ) @classmethod def add_post( - cls, - *, - title: Optional[str] = None, - content: Optional[str] = None, - authorId: Optional[str] = None, - publishedAt: Optional[Any] = None + cls, title: str, content: str, author_id: str, published_at: str ) -> PostFields: return PostFields( field_name="addPost", - title=title, - content=content, - authorId=authorId, - publishedAt=publishedAt, + arguments={ + "title": {"type": "String!", "value": title}, + "content": {"type": "String!", "value": content}, + "authorId": {"type": "ID!", "value": author_id}, + "publishedAt": {"type": "String!", "value": published_at}, + }, ) @classmethod def update_post( cls, + post_id: str, *, - post_id: Optional[str] = None, title: Optional[str] = None, content: Optional[str] = None, - publishedAt: Optional[Any] = None + published_at: Optional[str] = None ) -> PostFields: return PostFields( field_name="updatePost", - post_id=post_id, - title=title, - content=content, - publishedAt=publishedAt, + arguments={ + "post_id": {"type": "ID!", "value": post_id}, + "title": {"type": "String", "value": title}, + "content": {"type": "String", "value": content}, + "publishedAt": {"type": "String", "value": published_at}, + }, ) @classmethod - def delete_post(cls, *, post_id: Optional[str] = None) -> PostFields: - return PostFields(field_name="deletePost", post_id=post_id) + def delete_post(cls, post_id: str) -> PostFields: + return PostFields( + field_name="deletePost", + arguments={"post_id": {"type": "ID!", "value": post_id}}, + ) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_queries.py b/tests/main/custom_operation_builder/graphql_client/custom_queries.py index dd880ca0..42f20dce 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_queries.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_queries.py @@ -1,64 +1,47 @@ from typing import Optional -from .custom_fields import ( - AdminFields, - GuestFields, - PersonInterface, - PostFields, - UserFields, -) +from .custom_fields import PersonInterface, PostFields, UserFields from .custom_typing_fields import GraphQLField, SearchResultUnion class Query: @classmethod def hello(cls) -> GraphQLField: - return GraphQLField(field_name="hello") + return GraphQLField(field_name="hello", arguments={}) @classmethod def greeting(cls, *, name: Optional[str] = None) -> GraphQLField: - return GraphQLField(field_name="greeting", name=name) + return GraphQLField( + field_name="greeting", arguments={"name": {"type": "String", "value": name}} + ) @classmethod - def user(cls, *, user_id: Optional[str] = None) -> UserFields: - return UserFields(field_name="user", user_id=user_id) + def user(cls, user_id: str) -> UserFields: + return UserFields( + field_name="user", arguments={"user_id": {"type": "ID!", "value": user_id}} + ) @classmethod def users(cls) -> UserFields: - return UserFields(field_name="users") + return UserFields(field_name="users", arguments={}) @classmethod - def admin(cls, *, admin_id: Optional[str] = None) -> AdminFields: - return AdminFields(field_name="admin", admin_id=admin_id) - - @classmethod - def admins(cls) -> AdminFields: - return AdminFields(field_name="admins") - - @classmethod - def guest(cls, *, guest_id: Optional[str] = None) -> GuestFields: - return GuestFields(field_name="guest", guest_id=guest_id) - - @classmethod - def guests(cls) -> GuestFields: - return GuestFields(field_name="guests") - - @classmethod - def search(cls, *, text: Optional[str] = None) -> SearchResultUnion: - return SearchResultUnion(field_name="search", text=text) + def search(cls, text: str) -> SearchResultUnion: + return SearchResultUnion( + field_name="search", arguments={"text": {"type": "String!", "value": text}} + ) @classmethod def posts(cls) -> PostFields: - return PostFields(field_name="posts") - - @classmethod - def post(cls, *, post_id: Optional[str] = None) -> PostFields: - return PostFields(field_name="post", post_id=post_id) + return PostFields(field_name="posts", arguments={}) @classmethod - def person(cls, *, person_id: Optional[str] = None) -> PersonInterface: - return PersonInterface(field_name="person", person_id=person_id) + def person(cls, person_id: str) -> PersonInterface: + return PersonInterface( + field_name="person", + arguments={"person_id": {"type": "ID!", "value": person_id}}, + ) @classmethod def people(cls) -> PersonInterface: - return PersonInterface(field_name="people") + return PersonInterface(field_name="people", arguments={}) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py index 7c78ba55..d8476b7e 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py @@ -5,12 +5,6 @@ class PersonGraphQLField(GraphQLField): pass -class SearchResultUnion(GraphQLField): - def on(self, type_name: str, *subfields: GraphQLField) -> "SearchResultUnion": - self._inline_fragments[type_name] = subfields - return self - - class UserGraphQLField(GraphQLField): pass @@ -25,3 +19,9 @@ class GuestGraphQLField(GraphQLField): class PostGraphQLField(GraphQLField): pass + + +class SearchResultUnion(GraphQLField): + def on(self, type_name: str, *subfields: GraphQLField) -> "SearchResultUnion": + self._inline_fragments[type_name] = subfields + return self diff --git a/tests/main/custom_operation_builder/graphql_client/enums.py b/tests/main/custom_operation_builder/graphql_client/enums.py index 45e68c85..72d8ca4b 100644 --- a/tests/main/custom_operation_builder/graphql_client/enums.py +++ b/tests/main/custom_operation_builder/graphql_client/enums.py @@ -2,6 +2,5 @@ class Role(str, Enum): - USER = "USER" ADMIN = "ADMIN" - GUEST = "GUEST" + USER = "USER" diff --git a/tests/main/custom_operation_builder/graphql_client/input_types.py b/tests/main/custom_operation_builder/graphql_client/input_types.py index 8c993ec8..7c78cfaf 100644 --- a/tests/main/custom_operation_builder/graphql_client/input_types.py +++ b/tests/main/custom_operation_builder/graphql_client/input_types.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Optional from pydantic import Field @@ -8,10 +8,10 @@ class AddUserInput(BaseModel): name: str - age: Optional[int] = None - email: Optional[str] = None - role: Optional[Role] = Role.USER - created_at: Optional[Any] = Field(alias="createdAt", default=None) + age: int + email: str + role: Role + created_at: str = Field(alias="createdAt") class UpdateUserInput(BaseModel): @@ -19,4 +19,4 @@ class UpdateUserInput(BaseModel): age: Optional[int] = None email: Optional[str] = None role: Optional[Role] = None - created_at: Optional[Any] = Field(alias="createdAt", default=None) + created_at: Optional[str] = Field(alias="createdAt", default=None) diff --git a/tests/main/custom_operation_builder/schema.graphql b/tests/main/custom_operation_builder/schema.graphql new file mode 100644 index 00000000..34206824 --- /dev/null +++ b/tests/main/custom_operation_builder/schema.graphql @@ -0,0 +1,98 @@ +enum Role { + ADMIN + USER +} + +input AddUserInput { + name: String! + age: Int! + email: String! + role: Role! + createdAt: String! +} + +input UpdateUserInput { + name: String + age: Int + email: String + role: Role + createdAt: String +} + +interface PersonInterface { + id: ID! + name: String! + email: String! + metafield(key: String!): String +} + +type User implements PersonInterface { + id: ID! + name: String! + age: Int + email: String! + role: Role! + createdAt: String + friends: [User] +} + +type Admin implements PersonInterface { + id: ID! + name: String! + privileges: [String!]! + email: String! + createdAt: String +} + +type Guest implements PersonInterface { + id: ID! + name: String! + visitCount: Int + email: String! + createdAt: String +} + +type Post { + id: ID! + title: String! + content: String! + author: PersonInterface + publishedAt: String +} + +type Query { + hello: String + greeting(name: String): String + user(user_id: ID!): User + users: [User] + search(text: String!): [SearchResult] + posts: [Post] + person(person_id: ID!): PersonInterface + people: [PersonInterface] +} + +union SearchResult = User | Admin | Guest + +type Mutation { + addUser(user_input: AddUserInput!): User + updateUser(user_id: ID!, user_input: UpdateUserInput!): User + deleteUser(user_id: ID!): User + addPost( + title: String! + content: String! + authorId: ID! + publishedAt: String! + ): Post + updatePost( + post_id: ID! + title: String + content: String + publishedAt: String + ): Post + deletePost(post_id: ID!): Post +} + +schema { + query: Query + mutation: Mutation +} diff --git a/tests/main/custom_operation_builder/test_operation_build.py b/tests/main/custom_operation_builder/test_operation_build.py index ae3c7ac4..1149f118 100644 --- a/tests/main/custom_operation_builder/test_operation_build.py +++ b/tests/main/custom_operation_builder/test_operation_build.py @@ -14,14 +14,14 @@ def test_simple_hello(): - built_query = print_ast(Query.hello().to_ast()) + built_query = print_ast(Query.hello().to_ast(0)) expected_query = "hello" assert built_query == expected_query def test_greeting_with_name(): - built_query = print_ast(Query.greeting(name="Alice").to_ast()) - expected_query = 'greeting(name: "Alice")' + built_query = print_ast(Query.greeting(name="Alice").to_ast(0)) + expected_query = "greeting(name: $0_name)" assert built_query == expected_query @@ -34,9 +34,9 @@ def test_user_by_id(): UserFields.age, UserFields.email, ) - .to_ast() + .to_ast(0) ) - expected_query = 'user(user_id: "1") {\n id\n name\n age\n email\n}' + expected_query = "user(user_id: $0_user_id) {\n id\n name\n age\n email\n}" assert built_query == expected_query @@ -49,7 +49,7 @@ def test_all_users(): UserFields.age, UserFields.email, ) - .to_ast() + .to_ast(0) ) expected_query = "users {\n id\n name\n age\n email\n}" assert built_query == expected_query @@ -69,10 +69,10 @@ def test_user_with_friends(): ), UserFields.created_at, ) - .to_ast() + .to_ast(0) ) expected_query = ( - 'user(user_id: "1") {\n' + "user(user_id: $0_user_id) {\n" " id\n" " name\n" " age\n" @@ -111,10 +111,10 @@ def test_search_example(): GuestFields.visit_count, GuestFields.created_at, ) - .to_ast() + .to_ast(0) ) expected_query = ( - 'search(text: "example") {\n' + "search(text: $0_text) {\n" " ... on User {\n" " id\n" " name\n" @@ -150,7 +150,7 @@ def test_posts_with_authors(): ), PostFields.published_at, ) - .to_ast() + .to_ast(0) ) expected_query = ( "posts {\n" @@ -174,10 +174,10 @@ def test_get_person(): .fields(PersonInterface.id, PersonInterface.name, PersonInterface.email) .on("User", UserFields.age, UserFields.role) .on("Admin", AdminFields.privileges) - .to_ast() + .to_ast(0) ) expected_query = ( - 'person(person_id: "1") {\n' + "person(person_id: $0_person_id) {\n" " id\n" " name\n" " email\n" @@ -199,7 +199,7 @@ def test_get_people(): .fields(PersonInterface.id, PersonInterface.name, PersonInterface.email) .on("User", UserFields.age, UserFields.role) .on("Admin", AdminFields.privileges) - .to_ast() + .to_ast(0) ) expected_query = ( "people {\n" @@ -219,30 +219,26 @@ def test_get_people(): def test_add_user_mutation(): - built_mutation = print_ast( - Mutation.add_user( - user_input=AddUserInput( - name="bob", - age=30, - email="bob@example.com", - role=Role.ADMIN, - createdAt="2024-06-07T00:00:00.000Z", - ) - ) - .fields( - UserFields.id, - UserFields.name, - UserFields.age, - UserFields.email, - UserFields.role, - UserFields.created_at, + mutation = Mutation.add_user( + user_input=AddUserInput( + name="bob", + age=30, + email="bob@example.com", + role=Role.ADMIN, + createdAt="2024-06-07T00:00:00.000Z", ) - .to_ast() + ).fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.role, + UserFields.created_at, ) + built_mutation = print_ast(mutation.to_ast(0)) expected_mutation = ( - "addUser(\n" - ' user_input: {name: "bob", age: 30, email: "bob@example.com", role: "ADMIN", ' - 'created_at: "2024-06-07T00:00:00.000Z"}\n' + "addUser(" + "user_input: $0_user_input" ") {\n" " id\n" " name\n" @@ -253,6 +249,19 @@ def test_add_user_mutation(): "}" ) assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "0_user_input": { + "name": "user_input", + "type": "AddUserInput!", + "value": AddUserInput( + name="bob", + age=30, + email="bob@example.com", + role=Role.ADMIN, + created_at="2024-06-07T00:00:00.000Z", + ), + } + } def test_update_user_mutation(): @@ -275,14 +284,12 @@ def test_update_user_mutation(): UserFields.role, UserFields.created_at, ) - .to_ast() + .to_ast(0) ) expected_mutation = ( - "updateUser(\n" - ' user_id: "1"\n' - " user_input: " - '{name: "Alice Updated", age: 25, email: "alice.updated@example.com", ' - 'role: "USER", created_at: "2024-06-07T00:00:00.000Z"}\n' + "updateUser(" + "user_id: $0_user_id, " + "user_input: $0_user_input" ") {\n" " id\n" " name\n" @@ -302,9 +309,9 @@ def test_delete_user_mutation(): UserFields.id, UserFields.name, ) - .to_ast() + .to_ast(0) ) - expected_mutation = 'deleteUser(user_id: "1") {\n id\n name\n}' + expected_mutation = "deleteUser(user_id: $0_user_id) {\n id\n name\n}" assert built_mutation == expected_mutation @@ -313,8 +320,8 @@ def test_add_post_mutation(): Mutation.add_post( title="New Post", content="This is the content", - authorId="1", - publishedAt="2024-06-07T00:00:00.000Z", + author_id="1", + published_at="2024-06-07T00:00:00.000Z", ) .fields( PostFields.id, @@ -323,14 +330,14 @@ def test_add_post_mutation(): PostFields.author().fields(PersonInterface.id, PersonInterface.name), PostFields.published_at, ) - .to_ast() + .to_ast(0) ) expected_mutation = ( "addPost(\n" - ' title: "New Post"\n' - ' content: "This is the content"\n' - ' authorId: "1"\n' - ' publishedAt: "2024-06-07T00:00:00.000Z"\n' + " title: $0_title\n" + " content: $0_content\n" + " authorId: $0_authorId\n" + " publishedAt: $0_publishedAt\n" ") {\n" " id\n" " title\n" @@ -351,7 +358,7 @@ def test_update_post_mutation(): post_id="1", title="Updated Title", content="Updated Content", - publishedAt="2024-06-07T00:00:00.000Z", + published_at="2024-06-07T00:00:00.000Z", ) .fields( PostFields.id, @@ -359,14 +366,14 @@ def test_update_post_mutation(): PostFields.content, PostFields.published_at, ) - .to_ast() + .to_ast(0) ) expected_mutation = ( "updatePost(\n" - ' post_id: "1"\n' - ' title: "Updated Title"\n' - ' content: "Updated Content"\n' - ' publishedAt: "2024-06-07T00:00:00.000Z"\n' + " post_id: $0_post_id\n" + " title: $0_title\n" + " content: $0_content\n" + " publishedAt: $0_publishedAt\n" ") {\n" " id\n" " title\n" @@ -384,17 +391,17 @@ def test_delete_post_mutation(): PostFields.id, PostFields.title, ) - .to_ast() + .to_ast(0) ) - expected_mutation = 'deletePost(post_id: "1") {\n id\n title\n}' + expected_mutation = "deletePost(post_id: $0_post_id) {\n id\n title\n}" assert built_mutation == expected_mutation def test_user_specific_fields(): built_query = print_ast( - Query.user(user_id="1").fields(UserFields.id, UserFields.name).to_ast() + Query.user(user_id="1").fields(UserFields.id, UserFields.name).to_ast(0) ) - expected_query = 'user(user_id: "1") {\n id\n name\n}' + expected_query = "user(user_id: $0_user_id) {\n id\n name\n}" assert built_query == expected_query @@ -407,10 +414,10 @@ def test_user_with_friends_specific_fields(): UserFields.friends().fields(UserFields.id, UserFields.name), UserFields.created_at, ) - .to_ast() + .to_ast(0) ) expected_query = ( - 'user(user_id: "1") {\n' + "user(user_id: $0_user_id) {\n" " id\n" " name\n" " friends {\n" @@ -424,23 +431,25 @@ def test_user_with_friends_specific_fields(): def test_people_with_metadata(): - built_query = print_ast( + query = ( Query.people() .fields( PersonInterface.id, PersonInterface.name, PersonInterface.email, PersonInterface.metafield(key="bio"), + PersonInterface.metafield(key="ots"), ) .on("User", UserFields.age, UserFields.role) - .to_ast() ) + built_query = print_ast(query.to_ast(0)) expected_query = ( "people {\n" " id\n" " name\n" " email\n" - ' metafield(key: "bio")\n' + " metafield(key: $0_key)\n" + " metafield(key: $0_key_1)\n" " ... on User {\n" " age\n" " role\n" From 5334e00b76dba9a20fcabe097fa4e6c3de011072 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Fri, 12 Jul 2024 09:46:33 +0200 Subject: [PATCH 06/11] Add more tests and fix arguments creation --- ariadne_codegen/client_generators/client.py | 90 +++++++++++++------ .../client_generators/constants.py | 1 + .../client_generators/custom_operation.py | 81 +++++++++++++---- .../dependencies/base_operation.py | 2 +- .../expected_client/base_operation.py | 16 ++-- .../expected_client/client.py | 17 ++-- .../expected_client/custom_mutations.py | 8 +- .../expected_client/custom_queries.py | 44 +++++---- .../graphql_client/__init__.py | 4 +- .../graphql_client/base_operation.py | 2 +- .../graphql_client/client.py | 15 ++-- .../graphql_client/custom_fields.py | 28 +++--- .../graphql_client/custom_mutations.py | 80 +++++++++-------- .../graphql_client/custom_queries.py | 65 ++++++++++---- .../graphql_client/custom_typing_fields.py | 2 +- .../test_operation_build.py | 79 +++++++++------- 16 files changed, 346 insertions(+), 188 deletions(-) diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 2d8c610c..139f4af4 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -50,6 +50,7 @@ OPERATION_TYPE, OPTIONAL, PRINT_AST, + SELECTION_NODE, SELECTION_SET_NODE, TUPLE, TYPING_MODULE, @@ -457,23 +458,7 @@ def create_build_operation_ast_method(self): keywords=[ generate_keyword( arg="selections", - value=generate_list_comp( - elt=generate_call( - func=generate_attribute( - value=generate_name( - "field", - ), - attr="to_ast", - ), - args=[generate_name("idx")], - ), - generators=[ - generate_comp( - target="idx, field", - iter_="enumerate(fields)", - ) - ], - ), + value=generate_name("selections"), ) ], ), @@ -498,15 +483,10 @@ def create_build_operation_ast_method(self): args=[ generate_arg("self"), generate_arg( - "fields", + "selections", annotation=generate_subscript( - generate_name(TUPLE), - generate_tuple( - [ - generate_name("GraphQLField"), - generate_name("..."), - ] - ), + generate_name(LIST), + generate_name(SELECTION_NODE), ), ), generate_arg( @@ -531,6 +511,15 @@ def create_execute_custom_operation_method(self): variables_types_combined = generate_name("variables_types_combined") processed_variables_combined = generate_name("processed_variables_combined") method_body = [ + generate_assign( + targets=["selections"], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_build_selection_set" + ), + args=[generate_name("fields")], + ), + ), ast.Assign( targets=[ generate_tuple( @@ -561,7 +550,7 @@ def create_execute_custom_operation_method(self): value=generate_name("self"), attr="_build_operation_ast" ), args=[ - generate_name("fields"), + generate_name("selections"), generate_name("operation_type"), generate_name("operation_name"), generate_name("variable_definitions"), @@ -627,6 +616,53 @@ def create_execute_custom_operation_method(self): ), ) + def create_build_selection_set(self): + return generate_method_definition( + name="_build_selection_set", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + "fields", + annotation=generate_subscript( + generate_name("Tuple"), + generate_tuple( + [ + generate_name("GraphQLField"), + generate_name("..."), + ] + ), + ), + ), + ] + ), + body=[ + generate_return( + value=generate_list_comp( + elt=generate_call( + func=generate_attribute( + value=generate_name( + "field", + ), + attr="to_ast", + ), + args=[generate_name("idx")], + ), + generators=[ + generate_comp( + target="idx, field", + iter_="enumerate(fields)", + ) + ], + ), + ), + ], + return_type=generate_subscript( + generate_name(LIST), + generate_name(SELECTION_NODE), + ), + ) + def add_execute_custom_operation_method(self): self._add_import( generate_import_from( @@ -639,6 +675,7 @@ def add_execute_custom_operation_method(self): VARIABLE_DEFINITION_NODE, VARIABLE_NODE, NAMED_TYPE_NODE, + SELECTION_NODE, ], GRAPHQL_MODULE, ) @@ -654,6 +691,7 @@ def add_execute_custom_operation_method(self): self._class_def.body.append(self.create_combine_variables_method()) self._class_def.body.append(self.create_build_variable_definitions_method()) self._class_def.body.append(self.create_build_operation_ast_method()) + self._class_def.body.append(self.create_build_selection_set()) def create_custom_operation_method(self, name, operation_type): self._add_import( diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index a3258def..012b8637 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -28,6 +28,7 @@ OPERATION_DEFINITION_NODE = "OperationDefinitionNode" NAME_NODE = "NameNode" SELECTION_SET_NODE = "SelectionSetNode" +SELECTION_NODE = "SelectionNode" PRINT_AST = "print_ast" OPERATION_TYPE = "OperationType" VARIABLE_DEFINITION_NODE = "VariableDefinitionNode" diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py index 0fc9d102..5c73c54e 100644 --- a/ariadne_codegen/client_generators/custom_operation.py +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -13,11 +13,14 @@ ) from ..codegen import ( + generate_ann_assign, generate_annotation_name, generate_arg, generate_arguments, + generate_assign, generate_call, generate_class_def, + generate_comp, generate_constant, generate_dict, generate_import_from, @@ -26,6 +29,8 @@ generate_module, generate_name, generate_return, + generate_subscript, + generate_tuple, ) from ..exceptions import ParsingError from ..plugins.manager import PluginManager @@ -36,6 +41,7 @@ BASE_MODEL_FILE_PATH, CUSTOM_FIELDS_FILE_PATH, CUSTOM_FIELDS_TYPING_FILE_PATH, + DICT, INPUT_SCALARS_MAP, OPTIONAL, TYPING_MODULE, @@ -69,7 +75,7 @@ def __init__( self._imports: List[ast.ImportFrom] = [] self._type_imports: List[ast.ImportFrom] = [] - self._add_import(generate_import_from([OPTIONAL, ANY], TYPING_MODULE)) + self._add_import(generate_import_from([OPTIONAL, ANY, DICT], TYPING_MODULE)) self._class_def = generate_class_def(name=name, base_names=[]) @@ -114,7 +120,11 @@ def _generate_method( operation_args, final_type, ) -> ast.FunctionDef: - method_arguments, return_arguments = self._generate_arguments(operation_args) + ( + method_arguments, + return_arguments_keys, + return_arguments_values, + ) = self._generate_arguments(operation_args) return_type_name = self._get_return_type_and_from(final_type) return generate_method_definition( @@ -122,6 +132,52 @@ def _generate_method( arguments=method_arguments, return_type=generate_name(return_type_name), body=[ + generate_ann_assign( + "arguments", + generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_name(ANY), + ] + ), + ), + ] + ), + ), + generate_dict(return_arguments_keys, return_arguments_values), + ), + generate_assign( + ["cleared_arguments"], + ast.DictComp( + key=generate_name("key"), + value=generate_name("value"), + generators=[ + generate_comp( + target="key, value", + iter_="arguments.items()", + ifs=[ + ast.Compare( + left=generate_subscript( + value=generate_name("value"), + slice_=ast.Index( + value=generate_constant("value"), + ), # type: ignore + ), + ops=[ast.IsNot()], + comparators=[generate_constant(None)], + ) + ], + ) + ], + ), + ), generate_return( value=generate_call( func=generate_name(return_type_name), @@ -131,10 +187,13 @@ def _generate_method( arg="field_name", value=generate_constant(value=operation_name), ), - return_arguments, + generate_keyword( + arg="arguments", + value=generate_name("cleared_arguments"), + ), ], ) - ) + ), ], decorator_list=[generate_name("classmethod")], ) @@ -171,11 +230,8 @@ def _generate_arguments(self, operation_args): method_arguments = self._assemble_method_arguments( cls_arg, args, kw_only_args, kw_defaults ) - return_arguments = self._assemble_return_arguments( - return_arguments_keys, return_arguments_values - ) - return method_arguments, return_arguments + return method_arguments, return_arguments_keys, return_arguments_values def _accumulate_method_arguments( self, args, kw_only_args, kw_defaults, name, annotation, is_required @@ -231,15 +287,6 @@ def _assemble_method_arguments(self, cls_arg, args, kw_only_args, kw_defaults): kw_defaults=kw_defaults, ) - def _assemble_return_arguments(self, keys, values): - return generate_keyword( - arg="arguments", - value=generate_dict( - keys=keys, - values=values, - ), - ) - def _parse_graphql_type_name( self, type_, nullable: bool = True ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: diff --git a/ariadne_codegen/client_generators/dependencies/base_operation.py b/ariadne_codegen/client_generators/dependencies/base_operation.py index 1082ca4d..0695b558 100644 --- a/ariadne_codegen/client_generators/dependencies/base_operation.py +++ b/ariadne_codegen/client_generators/dependencies/base_operation.py @@ -69,7 +69,7 @@ def _build_selections( def _format_variable_name( self, idx: int, var_name: str, used_names: Set[str] ) -> str: - base_name = f"{idx}_{var_name}" + base_name = f"{var_name}_{idx}" unique_name = base_name counter = 1 while unique_name in used_names: diff --git a/tests/main/clients/custom_query_builder/expected_client/base_operation.py b/tests/main/clients/custom_query_builder/expected_client/base_operation.py index 9538b34b..0695b558 100644 --- a/tests/main/clients/custom_query_builder/expected_client/base_operation.py +++ b/tests/main/clients/custom_query_builder/expected_client/base_operation.py @@ -29,7 +29,7 @@ def __init__( ) -> None: self._field_name = field_name self._variables = arguments or {} - self._formatted_variables: Dict[str, Dict[str, Any]] = {} + self.formatted_variables: Dict[str, Dict[str, Any]] = {} self._subfields: List[GraphQLField] = [] self._alias: Optional[str] = None self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} @@ -69,7 +69,7 @@ def _build_selections( def _format_variable_name( self, idx: int, var_name: str, used_names: Set[str] ) -> str: - base_name = f"{idx}_{var_name}" + base_name = f"{var_name}_{idx}" unique_name = base_name counter = 1 while unique_name in used_names: @@ -79,10 +79,10 @@ def _format_variable_name( return unique_name def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: - self._formatted_variables = {} + self.formatted_variables = {} for k, v in self._variables.items(): unique_name = self._format_variable_name(idx, k, used_names) - self._formatted_variables[unique_name] = { + self.formatted_variables[unique_name] = { "name": k, "type": v["type"], "value": v["value"], @@ -94,7 +94,7 @@ def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: self._collect_all_variables(idx, used_names) formatted_args = [ GraphQLArgument(v["name"], k).to_ast() - for k, v in self._formatted_variables.items() + for k, v in self.formatted_variables.items() ] return FieldNode( name=NameNode(value=self._build_field_name()), @@ -107,12 +107,12 @@ def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: ) def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: - formatted_variables = self._formatted_variables + formatted_variables = self.formatted_variables for subfield in self._subfields: subfield.get_formatted_variables() - self._formatted_variables.update(subfield._formatted_variables) + self.formatted_variables.update(subfield.formatted_variables) for subfields in self._inline_fragments.values(): for subfield in subfields: subfield.get_formatted_variables() - self._formatted_variables.update(subfield._formatted_variables) + self.formatted_variables.update(subfield.formatted_variables) return formatted_variables diff --git a/tests/main/clients/custom_query_builder/expected_client/client.py b/tests/main/clients/custom_query_builder/expected_client/client.py index 6ba1aabf..1205603e 100644 --- a/tests/main/clients/custom_query_builder/expected_client/client.py +++ b/tests/main/clients/custom_query_builder/expected_client/client.py @@ -6,6 +6,7 @@ NameNode, OperationDefinitionNode, OperationType, + SelectionNode, SelectionSetNode, VariableDefinitionNode, VariableNode, @@ -24,6 +25,7 @@ class Client(AsyncBaseClient): async def execute_custom_operation( self, *fields: GraphQLField, operation_type: OperationType, operation_name: str ) -> Dict[str, Any]: + selections = self._build_selection_set(fields) variables_types_combined, processed_variables_combined = ( self._combine_variables(fields) ) @@ -31,7 +33,7 @@ async def execute_custom_operation( variables_types_combined ) operation_ast = self._build_operation_ast( - fields, operation_type, operation_name, variable_definitions + selections, operation_type, operation_name, variable_definitions ) response = await self.execute( print_ast(operation_ast), @@ -68,7 +70,7 @@ def _build_variable_definitions( def _build_operation_ast( self, - fields: Tuple[GraphQLField, ...], + selections: List[SelectionNode], operation_type: OperationType, operation_name: str, variable_definitions: List[VariableDefinitionNode], @@ -79,15 +81,16 @@ def _build_operation_ast( operation=operation_type, name=NameNode(value=operation_name), variable_definitions=variable_definitions, - selection_set=SelectionSetNode( - selections=[ - field.to_ast(idx) for idx, field in enumerate(fields) - ] - ), + selection_set=SelectionSetNode(selections=selections), ) ] ) + def _build_selection_set( + self, fields: Tuple[GraphQLField, ...] + ) -> List[SelectionNode]: + return [field.to_ast(idx) for idx, field in enumerate(fields)] + async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: return await self.execute_custom_operation( *fields, operation_type=OperationType.QUERY, operation_name=operation_name diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py index a4937a74..f4836ddb 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Dict, Optional from .custom_fields import UpdateMetadataFields @@ -6,6 +6,10 @@ class Mutation: @classmethod def update_metadata(cls, id: str) -> UpdateMetadataFields: + arguments: Dict[str, Dict[str, Any]] = {"id": {"type": "ID!", "value": id}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } return UpdateMetadataFields( - field_name="updateMetadata", arguments={"id": {"type": "ID!", "value": id}} + field_name="updateMetadata", arguments=cleared_arguments ) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py index f28580fb..90bfaf4f 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Dict, Optional from .custom_fields import ( AppFields, @@ -13,22 +13,33 @@ class Query: def products( cls, *, channel: Optional[str] = None, first: Optional[int] = None ) -> ProductCountableConnectionFields: + arguments: Dict[str, Dict[str, Any]] = { + "channel": {"type": "String", "value": channel}, + "first": {"type": "Int", "value": first}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } return ProductCountableConnectionFields( - field_name="products", - arguments={ - "channel": {"type": "String", "value": channel}, - "first": {"type": "Int", "value": first}, - }, + field_name="products", arguments=cleared_arguments ) @classmethod def app(cls) -> AppFields: - return AppFields(field_name="app", arguments={}) + arguments: Dict[str, Dict[str, Any]] = {} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return AppFields(field_name="app", arguments=cleared_arguments) @classmethod def product_types(cls) -> ProductTypeCountableConnectionFields: + arguments: Dict[str, Dict[str, Any]] = {} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } return ProductTypeCountableConnectionFields( - field_name="productTypes", arguments={} + field_name="productTypes", arguments=cleared_arguments ) @classmethod @@ -40,12 +51,15 @@ def translations( first: Optional[int] = None, last: Optional[int] = None ) -> TranslatableItemConnectionFields: + arguments: Dict[str, Dict[str, Any]] = { + "before": {"type": "String", "value": before}, + "after": {"type": "String", "value": after}, + "first": {"type": "Int", "value": first}, + "last": {"type": "Int", "value": last}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } return TranslatableItemConnectionFields( - field_name="translations", - arguments={ - "before": {"type": "String", "value": before}, - "after": {"type": "String", "value": after}, - "first": {"type": "Int", "value": first}, - "last": {"type": "Int", "value": last}, - }, + field_name="translations", arguments=cleared_arguments ) diff --git a/tests/main/custom_operation_builder/graphql_client/__init__.py b/tests/main/custom_operation_builder/graphql_client/__init__.py index ac2385e2..41117541 100644 --- a/tests/main/custom_operation_builder/graphql_client/__init__.py +++ b/tests/main/custom_operation_builder/graphql_client/__init__.py @@ -4,7 +4,7 @@ from .custom_typing_fields import ( AdminGraphQLField, GuestGraphQLField, - PersonGraphQLField, + PersonInterfaceGraphQLField, PostGraphQLField, SearchResultUnion, UserGraphQLField, @@ -31,7 +31,7 @@ "GraphQLClientHttpError", "GraphQLClientInvalidResponseError", "GuestGraphQLField", - "PersonGraphQLField", + "PersonInterfaceGraphQLField", "PostGraphQLField", "Role", "SearchResultUnion", diff --git a/tests/main/custom_operation_builder/graphql_client/base_operation.py b/tests/main/custom_operation_builder/graphql_client/base_operation.py index 1082ca4d..0695b558 100644 --- a/tests/main/custom_operation_builder/graphql_client/base_operation.py +++ b/tests/main/custom_operation_builder/graphql_client/base_operation.py @@ -69,7 +69,7 @@ def _build_selections( def _format_variable_name( self, idx: int, var_name: str, used_names: Set[str] ) -> str: - base_name = f"{idx}_{var_name}" + base_name = f"{var_name}_{idx}" unique_name = base_name counter = 1 while unique_name in used_names: diff --git a/tests/main/custom_operation_builder/graphql_client/client.py b/tests/main/custom_operation_builder/graphql_client/client.py index 6ba1aabf..141ece81 100644 --- a/tests/main/custom_operation_builder/graphql_client/client.py +++ b/tests/main/custom_operation_builder/graphql_client/client.py @@ -6,6 +6,7 @@ NameNode, OperationDefinitionNode, OperationType, + SelectionNode, SelectionSetNode, VariableDefinitionNode, VariableNode, @@ -24,6 +25,7 @@ class Client(AsyncBaseClient): async def execute_custom_operation( self, *fields: GraphQLField, operation_type: OperationType, operation_name: str ) -> Dict[str, Any]: + selections = self._build_selection_set(fields) variables_types_combined, processed_variables_combined = ( self._combine_variables(fields) ) @@ -31,7 +33,7 @@ async def execute_custom_operation( variables_types_combined ) operation_ast = self._build_operation_ast( - fields, operation_type, operation_name, variable_definitions + selections, operation_type, operation_name, variable_definitions ) response = await self.execute( print_ast(operation_ast), @@ -68,7 +70,7 @@ def _build_variable_definitions( def _build_operation_ast( self, - fields: Tuple[GraphQLField, ...], + selections: List[SelectionNode], operation_type: OperationType, operation_name: str, variable_definitions: List[VariableDefinitionNode], @@ -79,15 +81,14 @@ def _build_operation_ast( operation=operation_type, name=NameNode(value=operation_name), variable_definitions=variable_definitions, - selection_set=SelectionSetNode( - selections=[ - field.to_ast(idx) for idx, field in enumerate(fields) - ] - ), + selection_set=SelectionSetNode(selections=selections), ) ] ) + def _build_selection_set(self, fields: List[GraphQLField]) -> List[SelectionNode]: + return [field.to_ast(idx) for idx, field in enumerate(fields)] + async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: return await self.execute_custom_operation( *fields, operation_type=OperationType.QUERY, operation_name=operation_name diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py index e70a1c61..2c0856fe 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -3,7 +3,7 @@ from . import ( AdminGraphQLField, GuestGraphQLField, - PersonGraphQLField, + PersonInterfaceGraphQLField, PostGraphQLField, UserGraphQLField, ) @@ -46,22 +46,26 @@ def fields(self, *subfields: GuestGraphQLField) -> "GuestFields": return self -class PersonInterface(GraphQLField): - id: PersonGraphQLField = PersonGraphQLField("id") - name: PersonGraphQLField = PersonGraphQLField("name") - email: PersonGraphQLField = PersonGraphQLField("email") +class PersonInterfaceInterface(GraphQLField): + id: PersonInterfaceGraphQLField = PersonInterfaceGraphQLField("id") + name: PersonInterfaceGraphQLField = PersonInterfaceGraphQLField("name") + email: PersonInterfaceGraphQLField = PersonInterfaceGraphQLField("email") @classmethod - def metafield(cls, key: str) -> "PersonGraphQLField": - return PersonGraphQLField( + def metafield(cls, key: str) -> "PersonInterfaceGraphQLField": + return PersonInterfaceGraphQLField( "metafield", arguments={"key": {"type": "String!", "value": key}} ) - def fields(self, *subfields: PersonGraphQLField) -> "PersonInterface": + def fields( + self, *subfields: PersonInterfaceGraphQLField + ) -> "PersonInterfaceInterface": self._subfields.extend(subfields) return self - def on(self, type_name: str, *subfields: GraphQLField) -> "PersonInterface": + def on( + self, type_name: str, *subfields: GraphQLField + ) -> "PersonInterfaceInterface": self._inline_fragments[type_name] = subfields return self @@ -72,13 +76,13 @@ class PostFields(GraphQLField): content: PostGraphQLField = PostGraphQLField("content") @classmethod - def author(cls) -> "PersonInterface": - return PersonInterface("author", arguments={}) + def author(cls) -> "PersonInterfaceInterface": + return PersonInterfaceInterface("author", arguments={}) published_at: PostGraphQLField = PostGraphQLField("publishedAt") def fields( - self, *subfields: Union[PostGraphQLField, "PersonInterface"] + self, *subfields: Union[PostGraphQLField, "PersonInterfaceInterface"] ) -> "PostFields": self._subfields.extend(subfields) return self diff --git a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py index de4ba602..b0d0265d 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py @@ -7,41 +7,45 @@ class Mutation: @classmethod def add_user(cls, user_input: AddUserInput) -> UserFields: - return UserFields( - field_name="addUser", - arguments={"user_input": {"type": "AddUserInput!", "value": user_input}}, - ) + arguments = {"user_input": {"type": "AddUserInput!", "value": user_input}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="addUser", arguments=cleared_arguments) @classmethod def update_user(cls, user_id: str, user_input: UpdateUserInput) -> UserFields: - return UserFields( - field_name="updateUser", - arguments={ - "user_id": {"type": "ID!", "value": user_id}, - "user_input": {"type": "UpdateUserInput!", "value": user_input}, - }, - ) + arguments = { + "user_id": {"type": "ID!", "value": user_id}, + "user_input": {"type": "UpdateUserInput!", "value": user_input}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="updateUser", arguments=cleared_arguments) @classmethod def delete_user(cls, user_id: str) -> UserFields: - return UserFields( - field_name="deleteUser", - arguments={"user_id": {"type": "ID!", "value": user_id}}, - ) + arguments = {"user_id": {"type": "ID!", "value": user_id}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="deleteUser", arguments=cleared_arguments) @classmethod def add_post( cls, title: str, content: str, author_id: str, published_at: str ) -> PostFields: - return PostFields( - field_name="addPost", - arguments={ - "title": {"type": "String!", "value": title}, - "content": {"type": "String!", "value": content}, - "authorId": {"type": "ID!", "value": author_id}, - "publishedAt": {"type": "String!", "value": published_at}, - }, - ) + arguments = { + "title": {"type": "String!", "value": title}, + "content": {"type": "String!", "value": content}, + "authorId": {"type": "ID!", "value": author_id}, + "publishedAt": {"type": "String!", "value": published_at}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PostFields(field_name="addPost", arguments=cleared_arguments) @classmethod def update_post( @@ -52,19 +56,21 @@ def update_post( content: Optional[str] = None, published_at: Optional[str] = None ) -> PostFields: - return PostFields( - field_name="updatePost", - arguments={ - "post_id": {"type": "ID!", "value": post_id}, - "title": {"type": "String", "value": title}, - "content": {"type": "String", "value": content}, - "publishedAt": {"type": "String", "value": published_at}, - }, - ) + arguments = { + "post_id": {"type": "ID!", "value": post_id}, + "title": {"type": "String", "value": title}, + "content": {"type": "String", "value": content}, + "publishedAt": {"type": "String", "value": published_at}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PostFields(field_name="updatePost", arguments=cleared_arguments) @classmethod def delete_post(cls, post_id: str) -> PostFields: - return PostFields( - field_name="deletePost", - arguments={"post_id": {"type": "ID!", "value": post_id}}, - ) + arguments = {"post_id": {"type": "ID!", "value": post_id}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PostFields(field_name="deletePost", arguments=cleared_arguments) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_queries.py b/tests/main/custom_operation_builder/graphql_client/custom_queries.py index 42f20dce..11e522fa 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_queries.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_queries.py @@ -1,47 +1,74 @@ from typing import Optional -from .custom_fields import PersonInterface, PostFields, UserFields +from .custom_fields import PersonInterfaceInterface, PostFields, UserFields from .custom_typing_fields import GraphQLField, SearchResultUnion class Query: @classmethod def hello(cls) -> GraphQLField: - return GraphQLField(field_name="hello", arguments={}) + arguments = {} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return GraphQLField(field_name="hello", arguments=cleared_arguments) @classmethod def greeting(cls, *, name: Optional[str] = None) -> GraphQLField: - return GraphQLField( - field_name="greeting", arguments={"name": {"type": "String", "value": name}} - ) + arguments = {"name": {"type": "String", "value": name}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return GraphQLField(field_name="greeting", arguments=cleared_arguments) @classmethod def user(cls, user_id: str) -> UserFields: - return UserFields( - field_name="user", arguments={"user_id": {"type": "ID!", "value": user_id}} - ) + arguments = {"user_id": {"type": "ID!", "value": user_id}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="user", arguments=cleared_arguments) @classmethod def users(cls) -> UserFields: - return UserFields(field_name="users", arguments={}) + arguments = {} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="users", arguments=cleared_arguments) @classmethod def search(cls, text: str) -> SearchResultUnion: - return SearchResultUnion( - field_name="search", arguments={"text": {"type": "String!", "value": text}} - ) + arguments = {"text": {"type": "String!", "value": text}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return SearchResultUnion(field_name="search", arguments=cleared_arguments) @classmethod def posts(cls) -> PostFields: - return PostFields(field_name="posts", arguments={}) + arguments = {} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PostFields(field_name="posts", arguments=cleared_arguments) @classmethod - def person(cls, person_id: str) -> PersonInterface: - return PersonInterface( - field_name="person", - arguments={"person_id": {"type": "ID!", "value": person_id}}, + def person(cls, person_id: str) -> PersonInterfaceInterface: + arguments = {"person_id": {"type": "ID!", "value": person_id}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PersonInterfaceInterface( + field_name="person", arguments=cleared_arguments ) @classmethod - def people(cls) -> PersonInterface: - return PersonInterface(field_name="people", arguments={}) + def people(cls) -> PersonInterfaceInterface: + arguments = {} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PersonInterfaceInterface( + field_name="people", arguments=cleared_arguments + ) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py index d8476b7e..826f8f2a 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py @@ -1,7 +1,7 @@ from .base_operation import GraphQLField -class PersonGraphQLField(GraphQLField): +class PersonInterfaceGraphQLField(GraphQLField): pass diff --git a/tests/main/custom_operation_builder/test_operation_build.py b/tests/main/custom_operation_builder/test_operation_build.py index 1149f118..76d73740 100644 --- a/tests/main/custom_operation_builder/test_operation_build.py +++ b/tests/main/custom_operation_builder/test_operation_build.py @@ -3,7 +3,7 @@ from .graphql_client.custom_fields import ( AdminFields, GuestFields, - PersonInterface, + PersonInterfaceInterface, PostFields, UserFields, ) @@ -21,7 +21,7 @@ def test_simple_hello(): def test_greeting_with_name(): built_query = print_ast(Query.greeting(name="Alice").to_ast(0)) - expected_query = "greeting(name: $0_name)" + expected_query = "greeting(name: $name_0)" assert built_query == expected_query @@ -36,7 +36,7 @@ def test_user_by_id(): ) .to_ast(0) ) - expected_query = "user(user_id: $0_user_id) {\n id\n name\n age\n email\n}" + expected_query = "user(user_id: $user_id_0) {\n id\n name\n age\n email\n}" assert built_query == expected_query @@ -72,7 +72,7 @@ def test_user_with_friends(): .to_ast(0) ) expected_query = ( - "user(user_id: $0_user_id) {\n" + "user(user_id: $user_id_0) {\n" " id\n" " name\n" " age\n" @@ -114,7 +114,7 @@ def test_search_example(): .to_ast(0) ) expected_query = ( - "search(text: $0_text) {\n" + "search(text: $text_0) {\n" " ... on User {\n" " id\n" " name\n" @@ -146,7 +146,9 @@ def test_posts_with_authors(): PostFields.title, PostFields.content, PostFields.author().fields( - PersonInterface.id, PersonInterface.name, PersonInterface.email + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, ), PostFields.published_at, ) @@ -171,13 +173,17 @@ def test_posts_with_authors(): def test_get_person(): built_query = print_ast( Query.person(person_id="1") - .fields(PersonInterface.id, PersonInterface.name, PersonInterface.email) + .fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + ) .on("User", UserFields.age, UserFields.role) .on("Admin", AdminFields.privileges) .to_ast(0) ) expected_query = ( - "person(person_id: $0_person_id) {\n" + "person(person_id: $person_id_0) {\n" " id\n" " name\n" " email\n" @@ -196,7 +202,11 @@ def test_get_person(): def test_get_people(): built_query = print_ast( Query.people() - .fields(PersonInterface.id, PersonInterface.name, PersonInterface.email) + .fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + ) .on("User", UserFields.age, UserFields.role) .on("Admin", AdminFields.privileges) .to_ast(0) @@ -238,7 +248,7 @@ def test_add_user_mutation(): built_mutation = print_ast(mutation.to_ast(0)) expected_mutation = ( "addUser(" - "user_input: $0_user_input" + "user_input: $user_input_0" ") {\n" " id\n" " name\n" @@ -250,7 +260,7 @@ def test_add_user_mutation(): ) assert built_mutation == expected_mutation assert mutation.get_formatted_variables() == { - "0_user_input": { + "user_input_0": { "name": "user_input", "type": "AddUserInput!", "value": AddUserInput( @@ -288,8 +298,8 @@ def test_update_user_mutation(): ) expected_mutation = ( "updateUser(" - "user_id: $0_user_id, " - "user_input: $0_user_input" + "user_id: $user_id_0, " + "user_input: $user_input_0" ") {\n" " id\n" " name\n" @@ -311,7 +321,7 @@ def test_delete_user_mutation(): ) .to_ast(0) ) - expected_mutation = "deleteUser(user_id: $0_user_id) {\n id\n name\n}" + expected_mutation = "deleteUser(user_id: $user_id_0) {\n id\n name\n}" assert built_mutation == expected_mutation @@ -327,17 +337,20 @@ def test_add_post_mutation(): PostFields.id, PostFields.title, PostFields.content, - PostFields.author().fields(PersonInterface.id, PersonInterface.name), + PostFields.author().fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + ), PostFields.published_at, ) .to_ast(0) ) expected_mutation = ( "addPost(\n" - " title: $0_title\n" - " content: $0_content\n" - " authorId: $0_authorId\n" - " publishedAt: $0_publishedAt\n" + " title: $title_0\n" + " content: $content_0\n" + " authorId: $authorId_0\n" + " publishedAt: $publishedAt_0\n" ") {\n" " id\n" " title\n" @@ -370,10 +383,10 @@ def test_update_post_mutation(): ) expected_mutation = ( "updatePost(\n" - " post_id: $0_post_id\n" - " title: $0_title\n" - " content: $0_content\n" - " publishedAt: $0_publishedAt\n" + " post_id: $post_id_0\n" + " title: $title_0\n" + " content: $content_0\n" + " publishedAt: $publishedAt_0\n" ") {\n" " id\n" " title\n" @@ -393,7 +406,7 @@ def test_delete_post_mutation(): ) .to_ast(0) ) - expected_mutation = "deletePost(post_id: $0_post_id) {\n id\n title\n}" + expected_mutation = "deletePost(post_id: $post_id_0) {\n id\n title\n}" assert built_mutation == expected_mutation @@ -401,7 +414,7 @@ def test_user_specific_fields(): built_query = print_ast( Query.user(user_id="1").fields(UserFields.id, UserFields.name).to_ast(0) ) - expected_query = "user(user_id: $0_user_id) {\n id\n name\n}" + expected_query = "user(user_id: $user_id_0) {\n id\n name\n}" assert built_query == expected_query @@ -417,7 +430,7 @@ def test_user_with_friends_specific_fields(): .to_ast(0) ) expected_query = ( - "user(user_id: $0_user_id) {\n" + "user(user_id: $user_id_0) {\n" " id\n" " name\n" " friends {\n" @@ -434,11 +447,11 @@ def test_people_with_metadata(): query = ( Query.people() .fields( - PersonInterface.id, - PersonInterface.name, - PersonInterface.email, - PersonInterface.metafield(key="bio"), - PersonInterface.metafield(key="ots"), + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + PersonInterfaceInterface.metafield(key="bio"), + PersonInterfaceInterface.metafield(key="ots"), ) .on("User", UserFields.age, UserFields.role) ) @@ -448,8 +461,8 @@ def test_people_with_metadata(): " id\n" " name\n" " email\n" - " metafield(key: $0_key)\n" - " metafield(key: $0_key_1)\n" + " metafield(key: $key_0)\n" + " metafield(key: $key_0_1)\n" " ... on User {\n" " age\n" " role\n" From be193fd92025f73816ad8dd4e07d5bbbd23dd4e0 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Fri, 12 Jul 2024 11:51:04 +0200 Subject: [PATCH 07/11] Refactor code to meet 3.10 3.9 python needs --- ariadne_codegen/client_generators/client.py | 24 ++++++----------- .../expected_client/client.py | 15 ++++++----- .../graphql_client/client.py | 19 ++++++++------ .../graphql_client/custom_fields.py | 2 +- .../graphql_client/custom_mutations.py | 20 +++++++++----- .../graphql_client/custom_queries.py | 26 ++++++++++++------- 6 files changed, 58 insertions(+), 48 deletions(-) diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 139f4af4..84caf3dc 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -278,8 +278,9 @@ def create_combine_variables_method(self): lineno=1, ), generate_return( - value=generate_tuple( - elts=[ + value=generate_dict( + keys=[generate_constant("types"), generate_constant("values")], + values=[ generate_name("variables_types_combined"), generate_name("processed_variables_combined"), ], @@ -306,13 +307,10 @@ def create_combine_variables_method(self): ) returns = generate_subscript( - generate_name(TUPLE), + generate_name(DICT), generate_tuple( [ - generate_subscript( - generate_name(DICT), - generate_tuple([generate_name("str"), generate_name("Any")]), - ), + generate_name("str"), generate_subscript( generate_name(DICT), generate_tuple([generate_name("str"), generate_name("Any")]), @@ -508,8 +506,6 @@ def create_build_operation_ast_method(self): ) def create_execute_custom_operation_method(self): - variables_types_combined = generate_name("variables_types_combined") - processed_variables_combined = generate_name("processed_variables_combined") method_body = [ generate_assign( targets=["selections"], @@ -521,11 +517,7 @@ def create_execute_custom_operation_method(self): ), ), ast.Assign( - targets=[ - generate_tuple( - elts=[variables_types_combined, processed_variables_combined], - ) - ], + targets=[generate_name("combined_variables")], value=generate_call( func=generate_attribute( value=generate_name("self"), attr="_combine_variables" @@ -540,7 +532,7 @@ def create_execute_custom_operation_method(self): func=generate_attribute( value=generate_name("self"), attr="_build_variable_definitions" ), - args=[generate_name("variables_types_combined")], + args=[generate_name('combined_variables["types"]')], ), ), generate_assign( @@ -574,7 +566,7 @@ def create_execute_custom_operation_method(self): keywords=[ generate_keyword( arg="variables", - value=generate_name("processed_variables_combined"), + value=generate_name('combined_variables["values"]'), ), generate_keyword( arg="operation_name", diff --git a/tests/main/clients/custom_query_builder/expected_client/client.py b/tests/main/clients/custom_query_builder/expected_client/client.py index 1205603e..c89fa2c6 100644 --- a/tests/main/clients/custom_query_builder/expected_client/client.py +++ b/tests/main/clients/custom_query_builder/expected_client/client.py @@ -26,25 +26,23 @@ async def execute_custom_operation( self, *fields: GraphQLField, operation_type: OperationType, operation_name: str ) -> Dict[str, Any]: selections = self._build_selection_set(fields) - variables_types_combined, processed_variables_combined = ( - self._combine_variables(fields) - ) + combined_variables = self._combine_variables(fields) variable_definitions = self._build_variable_definitions( - variables_types_combined + combined_variables["types"] ) operation_ast = self._build_operation_ast( selections, operation_type, operation_name, variable_definitions ) response = await self.execute( print_ast(operation_ast), - variables=processed_variables_combined, + variables=combined_variables["values"], operation_name=operation_name, ) return self.get_data(response) def _combine_variables( self, fields: Tuple[GraphQLField, ...] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ) -> Dict[str, Dict[str, Any]]: variables_types_combined = {} processed_variables_combined = {} for field in fields: @@ -55,7 +53,10 @@ def _combine_variables( processed_variables_combined.update( {k: v["value"] for k, v in formatted_variables.items()} ) - return (variables_types_combined, processed_variables_combined) + return { + "types": variables_types_combined, + "values": processed_variables_combined, + } def _build_variable_definitions( self, variables_types_combined: Dict[str, str] diff --git a/tests/main/custom_operation_builder/graphql_client/client.py b/tests/main/custom_operation_builder/graphql_client/client.py index 141ece81..c89fa2c6 100644 --- a/tests/main/custom_operation_builder/graphql_client/client.py +++ b/tests/main/custom_operation_builder/graphql_client/client.py @@ -26,25 +26,23 @@ async def execute_custom_operation( self, *fields: GraphQLField, operation_type: OperationType, operation_name: str ) -> Dict[str, Any]: selections = self._build_selection_set(fields) - variables_types_combined, processed_variables_combined = ( - self._combine_variables(fields) - ) + combined_variables = self._combine_variables(fields) variable_definitions = self._build_variable_definitions( - variables_types_combined + combined_variables["types"] ) operation_ast = self._build_operation_ast( selections, operation_type, operation_name, variable_definitions ) response = await self.execute( print_ast(operation_ast), - variables=processed_variables_combined, + variables=combined_variables["values"], operation_name=operation_name, ) return self.get_data(response) def _combine_variables( self, fields: Tuple[GraphQLField, ...] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ) -> Dict[str, Dict[str, Any]]: variables_types_combined = {} processed_variables_combined = {} for field in fields: @@ -55,7 +53,10 @@ def _combine_variables( processed_variables_combined.update( {k: v["value"] for k, v in formatted_variables.items()} ) - return (variables_types_combined, processed_variables_combined) + return { + "types": variables_types_combined, + "values": processed_variables_combined, + } def _build_variable_definitions( self, variables_types_combined: Dict[str, str] @@ -86,7 +87,9 @@ def _build_operation_ast( ] ) - def _build_selection_set(self, fields: List[GraphQLField]) -> List[SelectionNode]: + def _build_selection_set( + self, fields: Tuple[GraphQLField, ...] + ) -> List[SelectionNode]: return [field.to_ast(idx) for idx, field in enumerate(fields)] async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py index 2c0856fe..0415824c 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Optional, Union from . import ( AdminGraphQLField, diff --git a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py index b0d0265d..6b271253 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional from .custom_fields import PostFields, UserFields from .input_types import AddUserInput, UpdateUserInput @@ -7,7 +7,9 @@ class Mutation: @classmethod def add_user(cls, user_input: AddUserInput) -> UserFields: - arguments = {"user_input": {"type": "AddUserInput!", "value": user_input}} + arguments: Dict[str, Dict[str, Any]] = { + "user_input": {"type": "AddUserInput!", "value": user_input} + } cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -15,7 +17,7 @@ def add_user(cls, user_input: AddUserInput) -> UserFields: @classmethod def update_user(cls, user_id: str, user_input: UpdateUserInput) -> UserFields: - arguments = { + arguments: Dict[str, Dict[str, Any]] = { "user_id": {"type": "ID!", "value": user_id}, "user_input": {"type": "UpdateUserInput!", "value": user_input}, } @@ -26,7 +28,9 @@ def update_user(cls, user_id: str, user_input: UpdateUserInput) -> UserFields: @classmethod def delete_user(cls, user_id: str) -> UserFields: - arguments = {"user_id": {"type": "ID!", "value": user_id}} + arguments: Dict[str, Dict[str, Any]] = { + "user_id": {"type": "ID!", "value": user_id} + } cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -36,7 +40,7 @@ def delete_user(cls, user_id: str) -> UserFields: def add_post( cls, title: str, content: str, author_id: str, published_at: str ) -> PostFields: - arguments = { + arguments: Dict[str, Dict[str, Any]] = { "title": {"type": "String!", "value": title}, "content": {"type": "String!", "value": content}, "authorId": {"type": "ID!", "value": author_id}, @@ -56,7 +60,7 @@ def update_post( content: Optional[str] = None, published_at: Optional[str] = None ) -> PostFields: - arguments = { + arguments: Dict[str, Dict[str, Any]] = { "post_id": {"type": "ID!", "value": post_id}, "title": {"type": "String", "value": title}, "content": {"type": "String", "value": content}, @@ -69,7 +73,9 @@ def update_post( @classmethod def delete_post(cls, post_id: str) -> PostFields: - arguments = {"post_id": {"type": "ID!", "value": post_id}} + arguments: Dict[str, Dict[str, Any]] = { + "post_id": {"type": "ID!", "value": post_id} + } cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } diff --git a/tests/main/custom_operation_builder/graphql_client/custom_queries.py b/tests/main/custom_operation_builder/graphql_client/custom_queries.py index 11e522fa..71ee4d8a 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_queries.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_queries.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional from .custom_fields import PersonInterfaceInterface, PostFields, UserFields from .custom_typing_fields import GraphQLField, SearchResultUnion @@ -7,7 +7,7 @@ class Query: @classmethod def hello(cls) -> GraphQLField: - arguments = {} + arguments: Dict[str, Dict[str, Any]] = {} cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -15,7 +15,9 @@ def hello(cls) -> GraphQLField: @classmethod def greeting(cls, *, name: Optional[str] = None) -> GraphQLField: - arguments = {"name": {"type": "String", "value": name}} + arguments: Dict[str, Dict[str, Any]] = { + "name": {"type": "String", "value": name} + } cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -23,7 +25,9 @@ def greeting(cls, *, name: Optional[str] = None) -> GraphQLField: @classmethod def user(cls, user_id: str) -> UserFields: - arguments = {"user_id": {"type": "ID!", "value": user_id}} + arguments: Dict[str, Dict[str, Any]] = { + "user_id": {"type": "ID!", "value": user_id} + } cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -31,7 +35,7 @@ def user(cls, user_id: str) -> UserFields: @classmethod def users(cls) -> UserFields: - arguments = {} + arguments: Dict[str, Dict[str, Any]] = {} cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -39,7 +43,9 @@ def users(cls) -> UserFields: @classmethod def search(cls, text: str) -> SearchResultUnion: - arguments = {"text": {"type": "String!", "value": text}} + arguments: Dict[str, Dict[str, Any]] = { + "text": {"type": "String!", "value": text} + } cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -47,7 +53,7 @@ def search(cls, text: str) -> SearchResultUnion: @classmethod def posts(cls) -> PostFields: - arguments = {} + arguments: Dict[str, Dict[str, Any]] = {} cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -55,7 +61,9 @@ def posts(cls) -> PostFields: @classmethod def person(cls, person_id: str) -> PersonInterfaceInterface: - arguments = {"person_id": {"type": "ID!", "value": person_id}} + arguments: Dict[str, Dict[str, Any]] = { + "person_id": {"type": "ID!", "value": person_id} + } cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } @@ -65,7 +73,7 @@ def person(cls, person_id: str) -> PersonInterfaceInterface: @classmethod def people(cls) -> PersonInterfaceInterface: - arguments = {} + arguments: Dict[str, Dict[str, Any]] = {} cleared_arguments = { key: value for key, value in arguments.items() if value["value"] is not None } From 7b63767f6f3cb399ce8aa0141c98f93ad7dc2059 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Fri, 12 Jul 2024 11:52:50 +0200 Subject: [PATCH 08/11] Fix pylint --- .../custom_operation_builder/graphql_client/custom_fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py index 0415824c..2c0856fe 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Union from . import ( AdminGraphQLField, From 9fdc6eb3e3496e8ed4fe85a1d46686454fc48bca Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Tue, 16 Jul 2024 18:08:35 +0200 Subject: [PATCH 09/11] Refactor in code structure --- .../client_generators/constants.py | 6 + .../client_generators/custom_arguments.py | 277 +++++++++++ .../client_generators/custom_fields.py | 451 ++++++++---------- .../client_generators/custom_fields_typing.py | 84 ++-- .../{utils.py => custom_generator_utils.py} | 0 .../client_generators/custom_operation.py | 264 ++-------- .../dependencies/base_operation.py | 68 ++- ariadne_codegen/client_generators/package.py | 5 - .../expected_client/__init__.py | 32 -- .../expected_client/base_operation.py | 68 ++- .../expected_client/custom_fields.py | 146 +++--- .../expected_client/custom_queries.py | 14 +- .../graphql_client/__init__.py | 14 - .../graphql_client/base_operation.py | 68 ++- .../graphql_client/custom_fields.py | 109 +++-- .../graphql_client/custom_queries.py | 26 +- .../custom_operation_builder/schema.graphql | 4 + .../test_operation_build.py | 385 +++++++++------ 18 files changed, 1150 insertions(+), 871 deletions(-) create mode 100644 ariadne_codegen/client_generators/custom_arguments.py rename ariadne_codegen/client_generators/{utils.py => custom_generator_utils.py} (100%) diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index 012b8637..e7339429 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -125,3 +125,9 @@ SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS" OPERATION_TYPES = ("Query", "Mutation", "Subscription") + +GRAPHQL_OBJECT_SUFFIX = "Fields" +GRAPHQL_INTERFACE_SUFFIX = "Interface" +GRAPHQL_FIELD_SUFFIX = "GraphQLField" +GRAPHQL_UNION_SUFFIX = "Union" +GRAPHQL_BASE_FIELD_CLASS = "GraphQLField" diff --git a/ariadne_codegen/client_generators/custom_arguments.py b/ariadne_codegen/client_generators/custom_arguments.py new file mode 100644 index 00000000..cd0700e4 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_arguments.py @@ -0,0 +1,277 @@ +import ast +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from graphql import ( + GraphQLEnumType, + GraphQLInputObjectType, + GraphQLInterfaceType, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLUnionType, +) + +from ..codegen import ( + generate_ann_assign, + generate_annotation_name, + generate_arg, + generate_arguments, + generate_assign, + generate_call, + generate_comp, + generate_constant, + generate_dict, + generate_import_from, + generate_keyword, + generate_name, + generate_subscript, + generate_tuple, +) +from ..exceptions import ParsingError +from ..plugins.manager import PluginManager +from ..utils import process_name +from .constants import ( + ANY, + BASE_MODEL_FILE_PATH, + DICT, + INPUT_SCALARS_MAP, + UPLOAD_CLASS_NAME, +) +from .custom_generator_utils import get_final_type +from .scalars import ScalarData, generate_scalar_imports + + +class ArgumentGenerator: + """Generates method arguments for GraphQL fields.""" + + def __init__( + self, + custom_scalars: Dict[str, ScalarData], + convert_to_snake_case: bool, + plugin_manager: Optional[PluginManager] = None, + ) -> None: + self.custom_scalars = custom_scalars + self.convert_to_snake_case = convert_to_snake_case + self.plugin_manager = plugin_manager + self.imports: List[ast.ImportFrom] = [] + self._used_custom_scalars: List[str] = [] + + def _add_import(self, import_: Optional[ast.ImportFrom] = None) -> None: + """Adds an import statement to the list of imports.""" + if import_: + if self.plugin_manager: + import_ = self.plugin_manager.generate_client_import(import_) + if import_.names: + self.imports.append(import_) + + def generate_arguments( + self, operation_args: Dict[str, Any] + ) -> Tuple[ast.arguments, List[ast.expr], List[ast.expr]]: + """Generates method arguments from operation arguments.""" + cls_arg = generate_arg(name="cls") + args: List[ast.arg] = [] + kw_only_args: List[ast.arg] = [] + kw_defaults: List[ast.expr] = [] + return_arguments_keys: List[ast.expr] = [] + return_arguments_values: List[ast.expr] = [] + + for arg_name, arg_value in operation_args.items(): + final_type = get_final_type(arg_value) + is_required = isinstance(arg_value.type, GraphQLNonNull) + name = process_name( + arg_name, convert_to_snake_case=self.convert_to_snake_case + ) + annotation, used_custom_scalar = self._parse_graphql_type_name( + final_type, not is_required + ) + + self._accumulate_method_arguments( + args, kw_only_args, kw_defaults, name, annotation, is_required + ) + self._accumulate_return_arguments( + return_arguments_keys, + return_arguments_values, + arg_name, + name, + final_type, + is_required, + used_custom_scalar, + ) + + method_arguments = self._assemble_method_arguments( + cls_arg, args, kw_only_args, kw_defaults + ) + return method_arguments, return_arguments_keys, return_arguments_values + + def _accumulate_method_arguments( + self, + args: List[ast.arg], + kw_only_args: List[ast.arg], + kw_defaults: List[ast.expr], + name: str, + annotation: Optional[Union[ast.Name, ast.Subscript]], + is_required: bool, + ) -> None: + """Accumulates method arguments.""" + if is_required: + args.append(generate_arg(name=name, annotation=annotation)) + else: + kw_only_args.append(generate_arg(name=name, annotation=annotation)) + kw_defaults.append(generate_constant(value=None)) + + def _accumulate_return_arguments( + self, + return_arguments_keys: List[ast.expr], + return_arguments_values: List[ast.expr], + arg_name: str, + name: str, + final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], + is_required: bool, + used_custom_scalar: Optional[str], + ) -> None: + """Accumulates return arguments.""" + constant_value = f"{final_type.name}!" if is_required else final_type.name + return_arg_dict_value = self._generate_return_arg_value( + name, used_custom_scalar + ) + + return_arguments_keys.append(generate_constant(arg_name)) + return_arguments_values.append( + generate_dict( + keys=[generate_constant("type"), generate_constant("value")], + values=[generate_constant(constant_value), return_arg_dict_value], + ) + ) + + def _generate_return_arg_value( + self, name: str, used_custom_scalar: Optional[str] + ) -> Union[ast.Call, ast.Name]: + """Generates the return argument value.""" + return_arg_dict_value: Union[ast.Call, ast.Name] = generate_name(name) + + if used_custom_scalar: + self._used_custom_scalars.append(used_custom_scalar) + scalar_data = self.custom_scalars[used_custom_scalar] + if scalar_data.serialize_name: + return_arg_dict_value = generate_call( + func=generate_name(scalar_data.serialize_name), + args=[generate_name(name)], + ) + + return return_arg_dict_value + + def _assemble_method_arguments( + self, + cls_arg: ast.arg, + args: List[ast.arg], + kw_only_args: List[ast.arg], + kw_defaults: List[ast.expr], + ) -> ast.arguments: + """Assembles method arguments.""" + return generate_arguments( + args=[cls_arg, *args], + kwonlyargs=kw_only_args, + kw_defaults=kw_defaults, # type: ignore + ) + + def _parse_graphql_type_name( + self, + type_: Union[GraphQLScalarType, GraphQLInputObjectType, GraphQLEnumType], + nullable: bool = True, + ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: + """Parses the GraphQL type name and determines if it is a custom scalar.""" + name = type_.name + used_custom_scalar = None + if isinstance(type_, GraphQLInputObjectType): + self._add_import( + generate_import_from(names=[name], from_="input_types", level=1) + ) + elif isinstance(type_, GraphQLEnumType): + self._add_import(generate_import_from(names=[name], level=1)) + elif isinstance(type_, GraphQLScalarType): + if name not in self.custom_scalars: + name = INPUT_SCALARS_MAP.get(name, ANY) + if name == UPLOAD_CLASS_NAME: + self._add_import( + generate_import_from( + names=[UPLOAD_CLASS_NAME], + from_=BASE_MODEL_FILE_PATH.stem, + level=1, + ) + ) + else: + used_custom_scalar = name + name = self.custom_scalars[name].type_name + self._used_custom_scalars.append(used_custom_scalar) + else: + raise ParsingError(f"Incorrect argument type {name}") + return generate_annotation_name(name, nullable), used_custom_scalar + + def add_custom_scalar_imports(self) -> None: + """Adds imports for custom scalars used in the schema.""" + for custom_scalar_name in self._used_custom_scalars: + scalar_data = self.custom_scalars[custom_scalar_name] + for import_ in generate_scalar_imports(scalar_data): + self._add_import(import_) + + def generate_clear_arguments_section( + self, + return_arguments_keys: List[ast.expr], + return_arguments_values: List[ast.expr], + ) -> Tuple[List[ast.stmt], List[ast.keyword]]: + arguments_body = [ + generate_ann_assign( + "arguments", + generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_name(ANY), + ] + ), + ), + ] + ), + ), + generate_dict( + return_arguments_keys, + return_arguments_values, # type: ignore + ), + ), + generate_assign( + ["cleared_arguments"], + ast.DictComp( + key=generate_name("key"), + value=generate_name("value"), + generators=[ + generate_comp( + target="key, value", + iter_="arguments.items()", + ifs=cast( + List[ast.expr], + [ + ast.Compare( + left=generate_subscript( + value=generate_name("value"), + slice_=generate_constant("value"), + ), + ops=[ast.IsNot()], + comparators=[generate_constant(None)], + ) + ], + ), + ) + ], + ), + ), + ] + arguments_keyword = [ + generate_keyword(arg="arguments", value=generate_name("cleared_arguments")) + ] + return arguments_body, arguments_keyword diff --git a/ariadne_codegen/client_generators/custom_fields.py b/ariadne_codegen/client_generators/custom_fields.py index 7b80a62a..78cbad51 100644 --- a/ariadne_codegen/client_generators/custom_fields.py +++ b/ariadne_codegen/client_generators/custom_fields.py @@ -1,37 +1,26 @@ import ast -from typing import Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from graphql import ( - GraphQLEnumType, - GraphQLInputObjectType, GraphQLInterfaceType, - GraphQLNonNull, + GraphQLNamedType, GraphQLObjectType, - GraphQLScalarType, GraphQLSchema, GraphQLUnionType, ) -from ariadne_codegen.client_generators.scalars import ( - ScalarData, - generate_scalar_imports, -) -from ariadne_codegen.exceptions import ParsingError -from ariadne_codegen.plugins.manager import PluginManager +from ariadne_codegen.client_generators.custom_arguments import ArgumentGenerator from ..codegen import ( generate_ann_assign, - generate_annotation_name, generate_arg, generate_arguments, generate_attribute, generate_call, generate_class_def, generate_constant, - generate_dict, generate_expr, generate_import_from, - generate_keyword, generate_method_definition, generate_module, generate_name, @@ -39,21 +28,28 @@ generate_subscript, generate_union_annotation, ) +from ..plugins.manager import PluginManager from ..utils import process_name from .constants import ( ANY, - BASE_MODEL_FILE_PATH, + BASE_GRAPHQL_FIELD_CLASS_NAME, BASE_OPERATION_FILE_PATH, - INPUT_SCALARS_MAP, + DICT, + GRAPHQL_BASE_FIELD_CLASS, + GRAPHQL_INTERFACE_SUFFIX, + GRAPHQL_OBJECT_SUFFIX, + GRAPHQL_UNION_SUFFIX, OPTIONAL, TYPING_MODULE, UNION, - UPLOAD_CLASS_NAME, ) -from .utils import TypeCollector, get_final_type +from .custom_generator_utils import TypeCollector, get_final_type +from .scalars import ScalarData class CustomFieldsGenerator: + """Generates custom fields for a given GraphQL schema using Python's AST module.""" + def __init__( self, schema: GraphQLSchema, @@ -65,143 +61,180 @@ def __init__( self.convert_to_snake_case = convert_to_snake_case self.plugin_manager = plugin_manager self.custom_scalars = custom_scalars if custom_scalars else {} - self._visited_types: Set[str] = set() - self._field_classes: Set[str] = set() - self._generated_modules: Dict[str, ast.Module] = {} self._imports: List[ast.ImportFrom] = [ ast.ImportFrom( module=BASE_OPERATION_FILE_PATH.stem, - names=[ast.alias("GraphQLField")], + names=[ast.alias(BASE_GRAPHQL_FIELD_CLASS_NAME)], level=1, ) ] - self._used_custom_scalars: List[str] = [] - self._add_import(generate_import_from([OPTIONAL, UNION, ANY], TYPING_MODULE)) - + self._add_import( + generate_import_from( + [OPTIONAL, UNION, ANY, DICT], + TYPING_MODULE, + ) + ) + self.argument_generator = ArgumentGenerator( + self.custom_scalars, + self.convert_to_snake_case, + self.plugin_manager, + ) self._class_defs: List[ast.ClassDef] = self._parse_object_type_definitions( TypeCollector(self.schema).collect() ) - def _add_import(self, import_: Optional[ast.ImportFrom] = None): - if not import_: - return - if self.plugin_manager: - import_ = self.plugin_manager.generate_client_import(import_) - if import_.names: - self._imports.append(import_) - def generate(self) -> ast.Module: - self._add_custom_scalar_imports() + """Generates an AST module containing the custom fields and required imports.""" + self.argument_generator.add_custom_scalar_imports() module = generate_module( - body=( - cast(List[ast.stmt], self._imports) - + cast( - List[ast.stmt], - self._class_defs, - ) - ), + body=cast(List[ast.stmt], self._imports + self._class_defs), ) - return module - def _parse_object_type_definitions(self, class_definitions): + def _add_import(self, import_: Optional[ast.ImportFrom] = None) -> None: + """Adds an import statement to the list of imports.""" + if import_: + if self.plugin_manager: + import_ = self.plugin_manager.generate_client_import(import_) + if import_.names: + self._imports.append(import_) + + def _parse_object_type_definitions( + self, type_names: List[str] + ) -> List[ast.ClassDef]: + """ + Parses object type definitions from the schema + and generates AST class definitions. + """ class_defs = [] - interface_defs = [] - for type_name in class_definitions: + + for type_name in type_names: graphql_type = self.schema.get_type(type_name) - if isinstance(graphql_type, GraphQLObjectType): - class_def = self._generate_class_def_body( - definition=graphql_type, - class_name=f"{graphql_type.name}Fields", - ) - class_defs.append(class_def) - if isinstance(graphql_type, GraphQLInterfaceType): + if isinstance(graphql_type, (GraphQLObjectType, GraphQLInterfaceType)): class_def = self._generate_class_def_body( definition=graphql_type, - class_name=f"{graphql_type.name}Interface", - ) - class_def.body.append( - self._generate_on_method(f"{graphql_type.name}Interface") + class_name=f"{graphql_type.name}{self._get_suffix(graphql_type)}", ) + if isinstance(graphql_type, GraphQLInterfaceType): + class_def.body.append( + self._generate_on_method( + f"{graphql_type.name}{GRAPHQL_INTERFACE_SUFFIX}" + ) + ) class_defs.append(class_def) - return [*interface_defs, *class_defs] + + return class_defs def _generate_class_def_body( self, definition: Union[GraphQLObjectType, GraphQLInterfaceType], class_name: str, ) -> ast.ClassDef: - base_names = ["GraphQLField"] + """ + Generates the body of a class definition for a given GraphQL object + or interface type. + """ + base_names = [GRAPHQL_BASE_FIELD_CLASS] additional_fields_typing = set() - definition_fields: Dict[str, ast.ClassDef] = dict(definition.fields.items()) - for interface in definition.interfaces: - definition_fields.update(dict(interface.fields.items())) class_def = generate_class_def(name=class_name, base_names=base_names) - - for lineno, (org_name, field) in enumerate(definition_fields.items(), start=1): + for lineno, (org_name, field) in enumerate( + self._get_combined_fields(definition).items(), start=1 + ): name = process_name( - org_name, - convert_to_snake_case=self.convert_to_snake_case, + org_name, convert_to_snake_case=self.convert_to_snake_case ) final_type = get_final_type(field) - if isinstance(final_type, GraphQLObjectType): - field_name = f"{final_type.name}Fields" - class_def.body.append( - self.generate_product_type_method( - name, field_name, getattr(field, "args") - ) - ) + field_name, method_required = self._get_field_name( + final_type, definition.name + ) + if self._is_custom_type(final_type): additional_fields_typing.add(field_name) - elif isinstance(final_type, GraphQLInterfaceType): - field_name = f"{final_type.name}Interface" - class_def.body.append( - self.generate_product_type_method( - name, field_name, getattr(field, "args") - ) + class_def.body.append( + self._generate_class_field( + name, field_name, org_name, field, method_required, lineno ) - additional_fields_typing.add(field_name) - else: - field_name = f"{definition.name}GraphQLField" - - if isinstance(final_type, GraphQLUnionType): - field_name = f"{final_type.name}Union" - additional_fields_typing.add(field_name) - if getattr(field, "args"): - class_def.body.append( - self.generate_product_type_method( - name, field_name, getattr(field, "args") - ) - ) - else: - self._add_import(generate_import_from([field_name], level=1)) - field_class_name = generate_name(field_name) - - field_implementation = generate_ann_assign( - target=name, - annotation=field_class_name, - value=generate_call( - func=field_class_name, - args=[generate_constant(org_name)], - ), - lineno=lineno, - ) - - class_def.body.append(field_implementation) + ) class_def.body.append( self._generate_fields_method( class_name, definition.name, sorted(additional_fields_typing) ) ) - return class_def + def _get_combined_fields( + self, definition: Union[GraphQLObjectType, GraphQLInterfaceType] + ) -> Dict[str, ast.ClassDef]: + """Combines fields from the definition and its interfaces.""" + fields = dict(definition.fields.items()) + for interface in getattr(definition, "interfaces", []): + fields.update(dict(interface.fields.items())) + return fields + + def _get_field_name( + self, final_type: GraphQLNamedType, definition_name: str + ) -> Tuple[str, bool]: + """ + Returns the appropriate field name suffix based on the type of GraphQL type. + """ + if isinstance(final_type, GraphQLObjectType): + return f"{final_type.name}{GRAPHQL_OBJECT_SUFFIX}", True + if isinstance(final_type, GraphQLInterfaceType): + return f"{final_type.name}{GRAPHQL_INTERFACE_SUFFIX}", True + if isinstance(final_type, GraphQLUnionType): + field_name = f"{final_type.name}{GRAPHQL_UNION_SUFFIX}" + else: + field_name = f"{definition_name}{GRAPHQL_BASE_FIELD_CLASS}" + self._add_import( + generate_import_from( + [field_name], + from_="custom_typing_fields", + level=1, + ) + ) + return field_name, False + + def _is_custom_type( + self, + final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], + ) -> bool: + """Checks if the final type is a custom type (Object, Interface, or Union).""" + return isinstance( + final_type, (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType) + ) + + def _generate_class_field( + self, + name: str, + field_name: str, + org_name: str, + field: ast.ClassDef, + method_required: bool, + lineno: int, + ) -> ast.stmt: + """Handles the generation of field types.""" + if getattr(field, "args") or method_required: + return self.generate_product_type_method( + name, field_name, getattr(field, "args") + ) + return generate_ann_assign( + target=name, + annotation=generate_name(f'"{field_name}"'), + value=generate_call( + func=generate_name(field_name), args=[generate_constant(org_name)] + ), + lineno=lineno, + ) + def _generate_fields_method( - self, class_name: str, definition_name: str, additional_fields_typing: List + self, class_name: str, definition_name: str, additional_fields_typing: List[str] ) -> ast.FunctionDef: - field_class_name = generate_name(f"{definition_name}GraphQLField") + """Generates the `fields` method for a class.""" + field_class_name = generate_name(f"{definition_name}{GRAPHQL_BASE_FIELD_CLASS}") self._add_import( - generate_import_from([f"{definition_name}GraphQLField"], level=1) + generate_import_from( + [field_class_name.id], from_="custom_typing_fields", level=1 + ) ) fields_annotation: Union[ast.Name, ast.Subscript] = field_class_name if additional_fields_typing: @@ -222,11 +255,15 @@ def _generate_fields_method( ] ), body=[ + generate_expr( + value=generate_constant( + value=f"Subfields should come from the {class_name} class" + ) + ), generate_expr( value=generate_call( func=generate_attribute( - value=generate_name("self"), - attr="_subfields.extend", + value=generate_name("self"), attr="_subfields.extend" ), args=[generate_name("subfields")], ) @@ -236,99 +273,8 @@ def _generate_fields_method( return_type=generate_name(f'"{class_name}"'), ) - def _generate_kw_args_and_defaults(self, operation_args): - kw_only_args = [] - kw_defaults = [] - args = [] - for arg_name, arg_type in operation_args.items(): - arg_name = process_name( - arg_name, - convert_to_snake_case=self.convert_to_snake_case, - ) - arg_final_type = get_final_type(arg_type) - is_required = isinstance(arg_type.type, GraphQLNonNull) - annotation, _ = self._parse_graphql_type_name( - arg_final_type, - not is_required, - ) - arg = generate_arg(name=arg_name, annotation=annotation) - if is_required: - args.append(arg) - else: - kw_only_args.append(arg) - kw_defaults.append(generate_constant(value=None)) - return kw_only_args, kw_defaults, args - - def _get_dict_value(self, name: str, arg_value) -> Union[ast.Name, ast.Call]: - name = process_name( - name, - convert_to_snake_case=self.convert_to_snake_case, - ) - _, used_custom_scalar = self._parse_graphql_type_name(get_final_type(arg_value)) - if used_custom_scalar: - self._used_custom_scalars.append(used_custom_scalar) - scalar_data = self.custom_scalars[used_custom_scalar] - if scalar_data.serialize_name: - return generate_call( - func=generate_name(scalar_data.serialize_name), - args=[generate_name(name)], - ) - return generate_name(name) - - def _generate_arguments_dict(self, operation_args) -> Dict[ast.Constant, ast.Dict]: - arguments_dict = {} - for arg_name, arg_value in operation_args.items(): - final_type = get_final_type(arg_value) - is_required = isinstance(arg_value.type, GraphQLNonNull) - constant_value = f"{final_type.name}!" if is_required else final_type.name - arguments_dict[generate_constant(arg_name)] = generate_dict( - keys=[generate_constant("type"), generate_constant("value")], - values=[ - generate_constant(constant_value), - self._get_dict_value(arg_name, arg_value), - ], - ) - return arguments_dict - - def generate_product_type_method( - self, name, class_name, arguments=None - ) -> ast.FunctionDef: - arguments = arguments or {} - field_class_name = generate_name(class_name) - kw_only_args, kw_defaults, args = self._generate_kw_args_and_defaults( - arguments, - ) - return_arguments_dict = self._generate_arguments_dict(arguments) - - return_keyword = generate_keyword( - arg="arguments", - value=generate_dict( - keys=list(return_arguments_dict.keys()), - values=list(return_arguments_dict.values()), - ), - ) - - return generate_method_definition( - name, - arguments=generate_arguments( - args=[generate_arg(name="cls"), *args], - kwonlyargs=kw_only_args, - kw_defaults=kw_defaults, - ), - body=[ - generate_return( - value=generate_call( - func=field_class_name, - args=[generate_constant(name)], - keywords=[return_keyword], - ) - ), - ], - return_type=generate_name(f'"{class_name}"'), - decorator_list=[generate_name("classmethod")], - ) - def _generate_on_method(self, class_name: str) -> ast.FunctionDef: + """Generates the `on` method for a class.""" return generate_method_definition( "on", arguments=generate_arguments( @@ -336,13 +282,14 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef: generate_arg(name="self"), generate_arg(name="type_name", annotation=generate_name("str")), generate_arg( - name="*subfields", annotation=generate_name("GraphQLField") + name="*subfields", + annotation=generate_name(GRAPHQL_BASE_FIELD_CLASS), ), ] ), - body=[ - cast( - ast.stmt, + body=cast( + List[ast.stmt], + [ ast.Assign( targets=[ generate_subscript( @@ -356,45 +303,61 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef: value=generate_name("subfields"), lineno=1, ), - ), - generate_return(value=generate_name("self")), - ], + generate_return(value=generate_name("self")), + ], + ), return_type=generate_name(f'"{class_name}"'), ) - def _parse_graphql_type_name( - self, type_, nullable: bool = True - ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: - name = type_.name - used_custom_scalar = None - if isinstance(type_, GraphQLInputObjectType): - self._add_import( - generate_import_from(names=[name], from_="input_types", level=1) + def generate_product_type_method( + self, name: str, class_name: str, arguments: Optional[Dict[str, Any]] = None + ) -> ast.FunctionDef: + """Generates a method for a product type.""" + arguments = arguments or {} + field_class_name = generate_name(class_name) + ( + method_arguments, + return_arguments_keys, + return_arguments_values, + ) = self.argument_generator.generate_arguments(arguments) + self._imports.extend(self.argument_generator.imports) + arguments_body: List[ast.stmt] = [] + arguments_keyword: List[ast.keyword] = [] + + if arguments: + ( + arguments_body, + arguments_keyword, + ) = self.argument_generator.generate_clear_arguments_section( + return_arguments_keys, return_arguments_values ) - elif isinstance(type_, GraphQLEnumType): - self._add_import(generate_import_from(names=[name], level=1)) - elif isinstance(type_, GraphQLScalarType): - if name not in self.custom_scalars: - name = INPUT_SCALARS_MAP.get(name, ANY) - if name == UPLOAD_CLASS_NAME: - self._add_import( - generate_import_from( - names=[UPLOAD_CLASS_NAME], - from_=BASE_MODEL_FILE_PATH.stem, - level=1, - ) - ) - else: - used_custom_scalar = name - name = self.custom_scalars[name].type_name - self._used_custom_scalars.append(used_custom_scalar) - else: - raise ParsingError(f"Incorrect argument type {name}") - return generate_annotation_name(name, nullable), used_custom_scalar + return generate_method_definition( + name, + arguments=method_arguments, + body=cast( + List[ast.stmt], + [ + *arguments_body, + generate_return( + value=generate_call( + func=field_class_name, + args=[generate_constant(name)], + keywords=arguments_keyword, + ) + ), + ], + ), + return_type=generate_name(f'"{class_name}"'), + decorator_list=[generate_name("classmethod")], + ) - def _add_custom_scalar_imports(self): - for custom_scalar_name in self._used_custom_scalars: - scalar_data = self.custom_scalars[custom_scalar_name] - for import_ in generate_scalar_imports(scalar_data): - self._add_import(import_) + def _get_suffix( + self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType] + ) -> str: + """Gets the appropriate suffix for a GraphQL type.""" + if isinstance(graphql_type, GraphQLObjectType): + return GRAPHQL_OBJECT_SUFFIX + if isinstance(graphql_type, GraphQLInterfaceType): + return GRAPHQL_INTERFACE_SUFFIX + raise ValueError(f"Unexpected graphql_type: {graphql_type}") diff --git a/ariadne_codegen/client_generators/custom_fields_typing.py b/ariadne_codegen/client_generators/custom_fields_typing.py index d38a2fc8..8c9fd6d3 100644 --- a/ariadne_codegen/client_generators/custom_fields_typing.py +++ b/ariadne_codegen/client_generators/custom_fields_typing.py @@ -8,7 +8,7 @@ GraphQLUnionType, ) -from ariadne_codegen.client_generators.utils import get_final_type +from ariadne_codegen.client_generators.custom_generator_utils import get_final_type from ..codegen import ( generate_arg, @@ -21,18 +21,19 @@ generate_return, generate_subscript, ) -from .constants import BASE_OPERATION_FILE_PATH, OPERATION_TYPES +from .constants import ( + BASE_OPERATION_FILE_PATH, + GRAPHQL_BASE_FIELD_CLASS, + OPERATION_TYPES, +) class CustomFieldsTypingGenerator: - def __init__( - self, - schema: GraphQLSchema, - ) -> None: + def __init__(self, schema: GraphQLSchema) -> None: self.schema = schema self.graphql_field_import = ast.ImportFrom( module=BASE_OPERATION_FILE_PATH.stem, - names=[ast.alias("GraphQLField")], + names=[ast.alias(GRAPHQL_BASE_FIELD_CLASS)], level=1, ) self._public_names: List[str] = [] @@ -41,14 +42,19 @@ def __init__( ] def generate(self) -> ast.Module: + """ + Generates an AST module containing the custom fields and required imports. + """ return generate_module( - body=( - cast(List[ast.stmt], [self.graphql_field_import]) - + cast(List[ast.stmt], [self._class_defs]) - ) + body=cast(List[ast.stmt], [self.graphql_field_import]) + + cast(List[ast.stmt], self._class_defs) ) - def _filter_types(self): + def _filter_types(self) -> List[ast.ClassDef]: + """ + Filters GraphQL types to include only objects, interfaces, and unions, + excluding internal and operation types. + """ return [ get_final_type(definition) for name, definition in self.schema.type_map.items() @@ -59,22 +65,34 @@ def _filter_types(self): and name not in OPERATION_TYPES ] - def _generate_field_class(self, class_def: ast.ClassDef) -> ast.ClassDef: - class_name = f"{class_def.name}GraphQLField" + def _generate_field_class( + self, + graphql_type: ast.ClassDef, + ) -> ast.ClassDef: + """ + Generates a field class for the given GraphQL type. + """ + class_name = f"{graphql_type.name}{GRAPHQL_BASE_FIELD_CLASS}" class_body: List[ast.stmt] = [] - if isinstance(class_def, GraphQLUnionType): - class_name = f"{class_def.name}Union" + + if isinstance(graphql_type, GraphQLUnionType): + class_name = f"{graphql_type.name}Union" class_body.append(self._generate_on_method(class_name)) + if class_name not in self._public_names: self._public_names.append(class_name) + field_class_def = generate_class_def( name=class_name, - base_names=["GraphQLField"], + base_names=[GRAPHQL_BASE_FIELD_CLASS], body=class_body if class_body else cast(List[ast.stmt], [ast.Pass()]), ) return field_class_def def _generate_on_method(self, class_name: str) -> ast.FunctionDef: + """ + Generates the `on` method for a class. + """ return generate_method_definition( "on", arguments=generate_arguments( @@ -82,26 +100,23 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef: generate_arg(name="self"), generate_arg(name="type_name", annotation=generate_name("str")), generate_arg( - name="*subfields", annotation=generate_name("GraphQLField") + name="*subfields", + annotation=generate_name(GRAPHQL_BASE_FIELD_CLASS), ), ] ), body=[ - cast( - ast.stmt, - ast.Assign( - targets=[ - generate_subscript( - value=generate_attribute( - value=generate_name("self"), - attr="_inline_fragments", - ), - slice_=generate_name("type_name"), - ) - ], - value=generate_name("subfields"), - lineno=1, - ), + ast.Assign( + targets=[ + generate_subscript( + value=generate_attribute( + value=generate_name("self"), attr="_inline_fragments" + ), + slice_=generate_name("type_name"), + ) + ], + value=generate_name("subfields"), + lineno=1, ), generate_return(value=generate_name("self")), ], @@ -109,4 +124,7 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef: ) def get_generated_public_names(self) -> List[str]: + """ + Returns the list of generated public names. + """ return self._public_names diff --git a/ariadne_codegen/client_generators/utils.py b/ariadne_codegen/client_generators/custom_generator_utils.py similarity index 100% rename from ariadne_codegen/client_generators/utils.py rename to ariadne_codegen/client_generators/custom_generator_utils.py diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py index 5c73c54e..f9ee64f6 100644 --- a/ariadne_codegen/client_generators/custom_operation.py +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -1,57 +1,50 @@ import ast -from typing import Dict, List, Optional, Tuple, Union, cast +from typing import Dict, List, Optional, cast from graphql import ( - GraphQLEnumType, GraphQLFieldMap, - GraphQLInputObjectType, GraphQLInterfaceType, - GraphQLNonNull, GraphQLObjectType, - GraphQLScalarType, GraphQLUnionType, ) +from ariadne_codegen.client_generators.custom_arguments import ArgumentGenerator + from ..codegen import ( - generate_ann_assign, - generate_annotation_name, - generate_arg, - generate_arguments, - generate_assign, generate_call, generate_class_def, - generate_comp, generate_constant, - generate_dict, generate_import_from, generate_keyword, generate_method_definition, generate_module, generate_name, generate_return, - generate_subscript, - generate_tuple, ) -from ..exceptions import ParsingError from ..plugins.manager import PluginManager -from ..utils import process_name, str_to_snake_case +from ..utils import str_to_snake_case from .arguments import ArgumentsGenerator from .constants import ( ANY, - BASE_MODEL_FILE_PATH, CUSTOM_FIELDS_FILE_PATH, CUSTOM_FIELDS_TYPING_FILE_PATH, DICT, - INPUT_SCALARS_MAP, + GRAPHQL_BASE_FIELD_CLASS, + GRAPHQL_INTERFACE_SUFFIX, + GRAPHQL_OBJECT_SUFFIX, + GRAPHQL_UNION_SUFFIX, OPTIONAL, TYPING_MODULE, - UPLOAD_CLASS_NAME, ) -from .scalars import ScalarData, generate_scalar_imports -from .utils import get_final_type +from .custom_generator_utils import get_final_type +from .scalars import ScalarData class CustomOperationGenerator: + """ + Generates custom operations for a given GraphQL schema using Python's AST module. + """ + def __init__( self, graphql_fields: GraphQLFieldMap, @@ -69,18 +62,20 @@ def __init__( self.enums_module_name = enums_module_name self.plugin_manager = plugin_manager self.custom_scalars = custom_scalars if custom_scalars else {} - self._used_custom_scalars: List[str] = [] self.arguments_generator = arguments_generator self.convert_to_snake_case = convert_to_snake_case self._imports: List[ast.ImportFrom] = [] self._type_imports: List[ast.ImportFrom] = [] self._add_import(generate_import_from([OPTIONAL, ANY, DICT], TYPING_MODULE)) + self.argument_generator = ArgumentGenerator( + self.custom_scalars, + self.convert_to_snake_case, + self.plugin_manager, + ) self._class_def = generate_class_def(name=name, base_names=[]) - self._used_inputs: List[str] = [] - def generate(self) -> ast.Module: """Generate module with class definition of graphql client.""" for name, field in self.graphql_fields.items(): @@ -96,7 +91,7 @@ def generate(self) -> ast.Module: if not self._class_def.body: self._class_def.body.append(ast.Pass()) - self._add_custom_scalar_imports() + self.argument_generator.add_custom_scalar_imports() self._class_def.lineno = len(self._imports) + 3 @@ -108,6 +103,7 @@ def generate(self) -> ast.Module: return module def _add_import(self, import_: Optional[ast.ImportFrom] = None): + """Adds an import statement to the list of imports.""" if import_: if self.plugin_manager: import_ = self.plugin_manager.generate_client_import(import_) @@ -120,64 +116,33 @@ def _generate_method( operation_args, final_type, ) -> ast.FunctionDef: + """Generates a method definition for a given operation.""" ( method_arguments, return_arguments_keys, return_arguments_values, - ) = self._generate_arguments(operation_args) + ) = self.argument_generator.generate_arguments(operation_args) + self._imports.extend(self.argument_generator.imports) + return_type_name = self._get_return_type_and_from(final_type) + arguments_body: List[ast.stmt] = [] + arguments_keyword: List[ast.keyword] = [] + + if operation_args: + ( + arguments_body, + arguments_keyword, + ) = self.argument_generator.generate_clear_arguments_section( + return_arguments_keys, return_arguments_values + ) + return generate_method_definition( name=str_to_snake_case(operation_name), arguments=method_arguments, return_type=generate_name(return_type_name), body=[ - generate_ann_assign( - "arguments", - generate_subscript( - generate_name(DICT), - generate_tuple( - [ - generate_name("str"), - generate_subscript( - generate_name(DICT), - generate_tuple( - [ - generate_name("str"), - generate_name(ANY), - ] - ), - ), - ] - ), - ), - generate_dict(return_arguments_keys, return_arguments_values), - ), - generate_assign( - ["cleared_arguments"], - ast.DictComp( - key=generate_name("key"), - value=generate_name("value"), - generators=[ - generate_comp( - target="key, value", - iter_="arguments.items()", - ifs=[ - ast.Compare( - left=generate_subscript( - value=generate_name("value"), - slice_=ast.Index( - value=generate_constant("value"), - ), # type: ignore - ), - ops=[ast.IsNot()], - comparators=[generate_constant(None)], - ) - ], - ) - ], - ), - ), + *arguments_body, generate_return( value=generate_call( func=generate_name(return_type_name), @@ -187,10 +152,7 @@ def _generate_method( arg="field_name", value=generate_constant(value=operation_name), ), - generate_keyword( - arg="arguments", - value=generate_name("cleared_arguments"), - ), + *arguments_keyword, ], ) ), @@ -198,150 +160,21 @@ def _generate_method( decorator_list=[generate_name("classmethod")], ) - def _generate_arguments(self, operation_args): - cls_arg = generate_arg(name="cls") - args, kw_only_args, kw_defaults = [], [], [] - return_arguments_keys, return_arguments_values = [], [] - - for arg_name, arg_value in operation_args.items(): - final_type = get_final_type(arg_value) - is_required = isinstance(arg_value.type, GraphQLNonNull) - name = process_name( - arg_name, - convert_to_snake_case=self.convert_to_snake_case, - ) - annotation, used_custom_scalar = self._parse_graphql_type_name( - final_type, not is_required - ) - - self._accumulate_method_arguments( - args, kw_only_args, kw_defaults, name, annotation, is_required - ) - self._accumulate_return_arguments( - return_arguments_keys, - return_arguments_values, - arg_name, - name, - final_type, - is_required, - used_custom_scalar, - ) - - method_arguments = self._assemble_method_arguments( - cls_arg, args, kw_only_args, kw_defaults - ) - - return method_arguments, return_arguments_keys, return_arguments_values - - def _accumulate_method_arguments( - self, args, kw_only_args, kw_defaults, name, annotation, is_required - ): - if is_required: - args.append(generate_arg(name=name, annotation=annotation)) - else: - kw_only_args.append(generate_arg(name=name, annotation=annotation)) - kw_defaults.append(generate_constant(value=None)) - - def _accumulate_return_arguments( - self, - return_arguments_keys, - return_arguments_values, - arg_name, - name, - final_type, - is_required, - used_custom_scalar, - ): - constant_value = f"{final_type.name}!" if is_required else final_type.name - return_arg_dict_value = self._generate_return_arg_value( - name, - used_custom_scalar, - ) - - return_arguments_keys.append(generate_constant(arg_name)) - return_arguments_values.append( - generate_dict( - keys=[generate_constant("type"), generate_constant("value")], - values=[generate_constant(constant_value), return_arg_dict_value], - ) - ) - - def _generate_return_arg_value(self, name, used_custom_scalar): - return_arg_dict_value = generate_name(name) - - if used_custom_scalar: - self._used_custom_scalars.append(used_custom_scalar) - scalar_data = self.custom_scalars[used_custom_scalar] - if scalar_data.serialize_name: - return_arg_dict_value = generate_call( - func=generate_name(scalar_data.serialize_name), - args=[generate_name(name)], - ) - - return return_arg_dict_value - - def _assemble_method_arguments(self, cls_arg, args, kw_only_args, kw_defaults): - return generate_arguments( - args=[cls_arg, *args], - kwonlyargs=kw_only_args, - kw_defaults=kw_defaults, - ) - - def _parse_graphql_type_name( - self, type_, nullable: bool = True - ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: - name = type_.name - - used_custom_scalar = None - if isinstance(type_, GraphQLInputObjectType): - self._used_inputs.append(name) - self._add_import( - generate_import_from( - names=[name], - from_="input_types", - level=1, - ) - ) - elif isinstance(type_, GraphQLEnumType): - self._add_import( - generate_import_from( - names=[name], - from_=self.enums_module_name, - level=1, - ) - ) - elif isinstance(type_, GraphQLScalarType): - if name not in self.custom_scalars: - name = INPUT_SCALARS_MAP.get(name, ANY) - if name == UPLOAD_CLASS_NAME: - self._add_import( - generate_import_from( - names=[UPLOAD_CLASS_NAME], - from_=BASE_MODEL_FILE_PATH.stem, - level=1, - ) - ) - else: - used_custom_scalar = name - name = self.custom_scalars[name].type_name - self._used_custom_scalars.append(used_custom_scalar) - else: - raise ParsingError(f"Incorrect argument type {name}") - - return generate_annotation_name(name, nullable), used_custom_scalar - def _get_return_type_and_from(self, final_type): + """ + Determines the return type name and its import path based on the final type. + """ if isinstance(final_type, GraphQLObjectType): - return_type_name = f"{final_type.name}Fields" + return_type_name = f"{final_type.name}{GRAPHQL_OBJECT_SUFFIX}" from_ = CUSTOM_FIELDS_FILE_PATH.stem elif isinstance(final_type, GraphQLInterfaceType): - return_type_name = f"{final_type.name}Interface" + return_type_name = f"{final_type.name}{GRAPHQL_INTERFACE_SUFFIX}" from_ = CUSTOM_FIELDS_FILE_PATH.stem elif isinstance(final_type, GraphQLUnionType): - return_type_name = f"{final_type.name}Union" + return_type_name = f"{final_type.name}{GRAPHQL_UNION_SUFFIX}" from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem else: - return_type_name = "GraphQLField" + return_type_name = GRAPHQL_BASE_FIELD_CLASS from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem self._type_imports.append( generate_import_from( @@ -352,12 +185,7 @@ def _get_return_type_and_from(self, final_type): ) return return_type_name - def _add_custom_scalar_imports(self): - for custom_scalar_name in self._used_custom_scalars: - scalar_data = self.custom_scalars[custom_scalar_name] - for import_ in generate_scalar_imports(scalar_data): - self._add_import(import_) - @staticmethod def _capitalize_first_letter(s: str) -> str: + """Capitalizes the first letter of the given string.""" return s[0].upper() + s[1:] diff --git a/ariadne_codegen/client_generators/dependencies/base_operation.py b/ariadne_codegen/client_generators/dependencies/base_operation.py index 0695b558..9f9c7660 100644 --- a/ariadne_codegen/client_generators/dependencies/base_operation.py +++ b/ariadne_codegen/client_generators/dependencies/base_operation.py @@ -12,11 +12,16 @@ class GraphQLArgument: - def __init__(self, argument_name: str, argument_value: Any): + """ + Represents a GraphQL argument and allows conversion to an AST structure. + """ + + def __init__(self, argument_name: str, argument_value: Any) -> None: self._name = argument_name self._value = argument_value def to_ast(self) -> ArgumentNode: + """Converts the argument to an ArgumentNode AST object.""" return ArgumentNode( name=NameNode(value=self._name), value=VariableNode(name=NameNode(value=self._value)), @@ -24,6 +29,15 @@ def to_ast(self) -> ArgumentNode: class GraphQLField: + """ + Represents a GraphQL field with its name, arguments, subfields, alias, + and inline fragments. + + Attributes: + formatted_variables (Dict[str, Dict[str, Any]]): The formatted arguments + of the GraphQL field. + """ + def __init__( self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None ) -> None: @@ -35,24 +49,25 @@ def __init__( self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} def alias(self, alias: str) -> "GraphQLField": + """Sets an alias for the GraphQL field and returns the instance.""" self._alias = alias return self - def add_subfield(self, subfield: "GraphQLField") -> None: - self._subfields.append(subfield) - - def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> None: - self._inline_fragments[type_name] = subfields - def _build_field_name(self) -> str: + """Builds the field name, including the alias if present.""" return f"{self._alias}: {self._field_name}" if self._alias else self._field_name def _build_selections( self, idx: int, used_names: Set[str] ) -> List[Union[FieldNode, InlineFragmentNode]]: + """Builds the selection set for the current GraphQL field, + including subfields and inline fragments.""" + # Create selections from subfields selections: List[Union[FieldNode, InlineFragmentNode]] = [ subfield.to_ast(idx, used_names) for subfield in self._subfields ] + + # Add inline fragments for name, subfields in self._inline_fragments.items(): selections.append( InlineFragmentNode( @@ -64,22 +79,35 @@ def _build_selections( ), ) ) + return selections def _format_variable_name( self, idx: int, var_name: str, used_names: Set[str] ) -> str: + """Generates a unique variable name by appending an index and, + if necessary, an additional counter to avoid duplicates.""" base_name = f"{var_name}_{idx}" unique_name = base_name counter = 1 + + # Ensure the generated name is unique while unique_name in used_names: unique_name = f"{base_name}_{counter}" counter += 1 + + # Add the unique name to the set of used names used_names.add(unique_name) + return unique_name def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + """ + Collects and formats all variables for the current GraphQL field, + ensuring unique names. + """ self.formatted_variables = {} + for k, v in self._variables.items(): unique_name = self._format_variable_name(idx, k, used_names) self.formatted_variables[unique_name] = { @@ -89,16 +117,18 @@ def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: } def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + """Converts the current GraphQL field to an AST (Abstract Syntax Tree) node.""" if used_names is None: used_names = set() + self._collect_all_variables(idx, used_names) - formatted_args = [ - GraphQLArgument(v["name"], k).to_ast() - for k, v in self.formatted_variables.items() - ] + return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=formatted_args, + arguments=[ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ], selection_set=( SelectionSetNode(selections=self._build_selections(idx, used_names)) if self._subfields or self._inline_fragments @@ -107,12 +137,20 @@ def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: ) def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: - formatted_variables = self.formatted_variables + """ + Retrieves all formatted variables for the current GraphQL field, + including those from subfields and inline fragments. + """ + formatted_variables = self.formatted_variables.copy() + + # Collect variables from subfields for subfield in self._subfields: subfield.get_formatted_variables() - self.formatted_variables.update(subfield.formatted_variables) + formatted_variables.update(subfield.formatted_variables) + + # Collect variables from inline fragments for subfields in self._inline_fragments.values(): for subfield in subfields: subfield.get_formatted_variables() - self.formatted_variables.update(subfield.formatted_variables) + formatted_variables.update(subfield.formatted_variables) return formatted_variables diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index b3adb052..a05cc594 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -398,11 +398,6 @@ def _generate_custom_fields_typing(self): code = self._add_comments_to_code(ast_to_str(module, False)) file_path.write_text(code) self._generated_files.append(file_path.name) - self.init_generator.add_import( - self.custom_fields_typing_generator.get_generated_public_names(), - self.custom_help_field_module_name, - 1, - ) def _generate_custom_fields(self): file_path = self.package_path / "custom_fields.py" diff --git a/tests/main/clients/custom_query_builder/expected_client/__init__.py b/tests/main/clients/custom_query_builder/expected_client/__init__.py index 9fb8977c..99fb757c 100644 --- a/tests/main/clients/custom_query_builder/expected_client/__init__.py +++ b/tests/main/clients/custom_query_builder/expected_client/__init__.py @@ -1,23 +1,6 @@ from .async_base_client import AsyncBaseClient from .base_model import BaseModel, Upload from .client import Client -from .custom_typing_fields import ( - AppGraphQLField, - CollectionTranslatableContentGraphQLField, - MetadataErrorGraphQLField, - MetadataItemGraphQLField, - ObjectWithMetadataGraphQLField, - PageInfoGraphQLField, - ProductCountableConnectionGraphQLField, - ProductCountableEdgeGraphQLField, - ProductGraphQLField, - ProductTranslatableContentGraphQLField, - ProductTypeCountableConnectionGraphQLField, - TranslatableItemConnectionGraphQLField, - TranslatableItemEdgeGraphQLField, - TranslatableItemUnion, - UpdateMetadataGraphQLField, -) from .enums import MetadataErrorCode from .exceptions import ( GraphQLClientError, @@ -28,29 +11,14 @@ ) __all__ = [ - "AppGraphQLField", "AsyncBaseClient", "BaseModel", "Client", - "CollectionTranslatableContentGraphQLField", "GraphQLClientError", "GraphQLClientGraphQLError", "GraphQLClientGraphQLMultiError", "GraphQLClientHttpError", "GraphQLClientInvalidResponseError", "MetadataErrorCode", - "MetadataErrorGraphQLField", - "MetadataItemGraphQLField", - "ObjectWithMetadataGraphQLField", - "PageInfoGraphQLField", - "ProductCountableConnectionGraphQLField", - "ProductCountableEdgeGraphQLField", - "ProductGraphQLField", - "ProductTranslatableContentGraphQLField", - "ProductTypeCountableConnectionGraphQLField", - "TranslatableItemConnectionGraphQLField", - "TranslatableItemEdgeGraphQLField", - "TranslatableItemUnion", - "UpdateMetadataGraphQLField", "Upload", ] diff --git a/tests/main/clients/custom_query_builder/expected_client/base_operation.py b/tests/main/clients/custom_query_builder/expected_client/base_operation.py index 0695b558..9f9c7660 100644 --- a/tests/main/clients/custom_query_builder/expected_client/base_operation.py +++ b/tests/main/clients/custom_query_builder/expected_client/base_operation.py @@ -12,11 +12,16 @@ class GraphQLArgument: - def __init__(self, argument_name: str, argument_value: Any): + """ + Represents a GraphQL argument and allows conversion to an AST structure. + """ + + def __init__(self, argument_name: str, argument_value: Any) -> None: self._name = argument_name self._value = argument_value def to_ast(self) -> ArgumentNode: + """Converts the argument to an ArgumentNode AST object.""" return ArgumentNode( name=NameNode(value=self._name), value=VariableNode(name=NameNode(value=self._value)), @@ -24,6 +29,15 @@ def to_ast(self) -> ArgumentNode: class GraphQLField: + """ + Represents a GraphQL field with its name, arguments, subfields, alias, + and inline fragments. + + Attributes: + formatted_variables (Dict[str, Dict[str, Any]]): The formatted arguments + of the GraphQL field. + """ + def __init__( self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None ) -> None: @@ -35,24 +49,25 @@ def __init__( self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} def alias(self, alias: str) -> "GraphQLField": + """Sets an alias for the GraphQL field and returns the instance.""" self._alias = alias return self - def add_subfield(self, subfield: "GraphQLField") -> None: - self._subfields.append(subfield) - - def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> None: - self._inline_fragments[type_name] = subfields - def _build_field_name(self) -> str: + """Builds the field name, including the alias if present.""" return f"{self._alias}: {self._field_name}" if self._alias else self._field_name def _build_selections( self, idx: int, used_names: Set[str] ) -> List[Union[FieldNode, InlineFragmentNode]]: + """Builds the selection set for the current GraphQL field, + including subfields and inline fragments.""" + # Create selections from subfields selections: List[Union[FieldNode, InlineFragmentNode]] = [ subfield.to_ast(idx, used_names) for subfield in self._subfields ] + + # Add inline fragments for name, subfields in self._inline_fragments.items(): selections.append( InlineFragmentNode( @@ -64,22 +79,35 @@ def _build_selections( ), ) ) + return selections def _format_variable_name( self, idx: int, var_name: str, used_names: Set[str] ) -> str: + """Generates a unique variable name by appending an index and, + if necessary, an additional counter to avoid duplicates.""" base_name = f"{var_name}_{idx}" unique_name = base_name counter = 1 + + # Ensure the generated name is unique while unique_name in used_names: unique_name = f"{base_name}_{counter}" counter += 1 + + # Add the unique name to the set of used names used_names.add(unique_name) + return unique_name def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + """ + Collects and formats all variables for the current GraphQL field, + ensuring unique names. + """ self.formatted_variables = {} + for k, v in self._variables.items(): unique_name = self._format_variable_name(idx, k, used_names) self.formatted_variables[unique_name] = { @@ -89,16 +117,18 @@ def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: } def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + """Converts the current GraphQL field to an AST (Abstract Syntax Tree) node.""" if used_names is None: used_names = set() + self._collect_all_variables(idx, used_names) - formatted_args = [ - GraphQLArgument(v["name"], k).to_ast() - for k, v in self.formatted_variables.items() - ] + return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=formatted_args, + arguments=[ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ], selection_set=( SelectionSetNode(selections=self._build_selections(idx, used_names)) if self._subfields or self._inline_fragments @@ -107,12 +137,20 @@ def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: ) def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: - formatted_variables = self.formatted_variables + """ + Retrieves all formatted variables for the current GraphQL field, + including those from subfields and inline fragments. + """ + formatted_variables = self.formatted_variables.copy() + + # Collect variables from subfields for subfield in self._subfields: subfield.get_formatted_variables() - self.formatted_variables.update(subfield.formatted_variables) + formatted_variables.update(subfield.formatted_variables) + + # Collect variables from inline fragments for subfields in self._inline_fragments.values(): for subfield in subfields: subfield.get_formatted_variables() - self.formatted_variables.update(subfield.formatted_variables) + formatted_variables.update(subfield.formatted_variables) return formatted_variables diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py index 310c5d8b..60ab1e41 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py @@ -1,6 +1,7 @@ -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union -from . import ( +from .base_operation import GraphQLField +from .custom_typing_fields import ( AppGraphQLField, CollectionTranslatableContentGraphQLField, MetadataErrorGraphQLField, @@ -17,59 +18,62 @@ TranslatableItemUnion, UpdateMetadataGraphQLField, ) -from .base_operation import GraphQLField class AppFields(GraphQLField): - id: AppGraphQLField = AppGraphQLField("id") + id: "AppGraphQLField" = AppGraphQLField("id") def fields(self, *subfields: AppGraphQLField) -> "AppFields": + """Subfields should come from the AppFields class""" self._subfields.extend(subfields) return self class CollectionTranslatableContentFields(GraphQLField): - id: CollectionTranslatableContentGraphQLField = ( + id: "CollectionTranslatableContentGraphQLField" = ( CollectionTranslatableContentGraphQLField("id") ) - collection_id: CollectionTranslatableContentGraphQLField = ( + collection_id: "CollectionTranslatableContentGraphQLField" = ( CollectionTranslatableContentGraphQLField("collectionId") ) - seo_title: CollectionTranslatableContentGraphQLField = ( + seo_title: "CollectionTranslatableContentGraphQLField" = ( CollectionTranslatableContentGraphQLField("seoTitle") ) - seo_description: CollectionTranslatableContentGraphQLField = ( + seo_description: "CollectionTranslatableContentGraphQLField" = ( CollectionTranslatableContentGraphQLField("seoDescription") ) - name: CollectionTranslatableContentGraphQLField = ( + name: "CollectionTranslatableContentGraphQLField" = ( CollectionTranslatableContentGraphQLField("name") ) - description: CollectionTranslatableContentGraphQLField = ( + description: "CollectionTranslatableContentGraphQLField" = ( CollectionTranslatableContentGraphQLField("description") ) def fields( self, *subfields: CollectionTranslatableContentGraphQLField ) -> "CollectionTranslatableContentFields": + """Subfields should come from the CollectionTranslatableContentFields class""" self._subfields.extend(subfields) return self class MetadataErrorFields(GraphQLField): - field: MetadataErrorGraphQLField = MetadataErrorGraphQLField("field") - message: MetadataErrorGraphQLField = MetadataErrorGraphQLField("message") - code: MetadataErrorGraphQLField = MetadataErrorGraphQLField("code") + field: "MetadataErrorGraphQLField" = MetadataErrorGraphQLField("field") + message: "MetadataErrorGraphQLField" = MetadataErrorGraphQLField("message") + code: "MetadataErrorGraphQLField" = MetadataErrorGraphQLField("code") def fields(self, *subfields: MetadataErrorGraphQLField) -> "MetadataErrorFields": + """Subfields should come from the MetadataErrorFields class""" self._subfields.extend(subfields) return self class MetadataItemFields(GraphQLField): - key: MetadataItemGraphQLField = MetadataItemGraphQLField("key") - value: MetadataItemGraphQLField = MetadataItemGraphQLField("value") + key: "MetadataItemGraphQLField" = MetadataItemGraphQLField("key") + value: "MetadataItemGraphQLField" = MetadataItemGraphQLField("value") def fields(self, *subfields: MetadataItemGraphQLField) -> "MetadataItemFields": + """Subfields should come from the MetadataItemFields class""" self._subfields.extend(subfields) return self @@ -77,27 +81,38 @@ def fields(self, *subfields: MetadataItemGraphQLField) -> "MetadataItemFields": class ObjectWithMetadataInterface(GraphQLField): @classmethod def private_metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("private_metadata", arguments={}) + return MetadataItemFields("private_metadata") @classmethod def private_metafield(cls, key: str) -> "ObjectWithMetadataGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } return ObjectWithMetadataGraphQLField( - "private_metafield", arguments={"key": {"type": "String!", "value": key}} + "private_metafield", arguments=cleared_arguments ) @classmethod def metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("metadata", arguments={}) + return MetadataItemFields("metadata") @classmethod def metafield(cls, key: str) -> "ObjectWithMetadataGraphQLField": - return ObjectWithMetadataGraphQLField( - "metafield", arguments={"key": {"type": "String!", "value": key}} - ) + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ObjectWithMetadataGraphQLField("metafield", arguments=cleared_arguments) def fields( self, *subfields: Union[ObjectWithMetadataGraphQLField, "MetadataItemFields"] ) -> "ObjectWithMetadataInterface": + """Subfields should come from the ObjectWithMetadataInterface class""" self._subfields.extend(subfields) return self @@ -109,44 +124,54 @@ def on( class PageInfoFields(GraphQLField): - has_next_page: PageInfoGraphQLField = PageInfoGraphQLField("hasNextPage") - has_previous_page: PageInfoGraphQLField = PageInfoGraphQLField("hasPreviousPage") - start_cursor: PageInfoGraphQLField = PageInfoGraphQLField("startCursor") - end_cursor: PageInfoGraphQLField = PageInfoGraphQLField("endCursor") + has_next_page: "PageInfoGraphQLField" = PageInfoGraphQLField("hasNextPage") + has_previous_page: "PageInfoGraphQLField" = PageInfoGraphQLField("hasPreviousPage") + start_cursor: "PageInfoGraphQLField" = PageInfoGraphQLField("startCursor") + end_cursor: "PageInfoGraphQLField" = PageInfoGraphQLField("endCursor") def fields(self, *subfields: PageInfoGraphQLField) -> "PageInfoFields": + """Subfields should come from the PageInfoFields class""" self._subfields.extend(subfields) return self class ProductFields(GraphQLField): - id: ProductGraphQLField = ProductGraphQLField("id") - slug: ProductGraphQLField = ProductGraphQLField("slug") - name: ProductGraphQLField = ProductGraphQLField("name") + id: "ProductGraphQLField" = ProductGraphQLField("id") + slug: "ProductGraphQLField" = ProductGraphQLField("slug") + name: "ProductGraphQLField" = ProductGraphQLField("name") @classmethod def private_metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("private_metadata", arguments={}) + return MetadataItemFields("private_metadata") @classmethod def private_metafield(cls, key: str) -> "ProductGraphQLField": - return ProductGraphQLField( - "private_metafield", arguments={"key": {"type": "String!", "value": key}} - ) + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ProductGraphQLField("private_metafield", arguments=cleared_arguments) @classmethod def metadata(cls) -> "MetadataItemFields": - return MetadataItemFields("metadata", arguments={}) + return MetadataItemFields("metadata") @classmethod def metafield(cls, key: str) -> "ProductGraphQLField": - return ProductGraphQLField( - "metafield", arguments={"key": {"type": "String!", "value": key}} - ) + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ProductGraphQLField("metafield", arguments=cleared_arguments) def fields( self, *subfields: Union[ProductGraphQLField, "MetadataItemFields"] ) -> "ProductFields": + """Subfields should come from the ProductFields class""" self._subfields.extend(subfields) return self @@ -154,13 +179,13 @@ def fields( class ProductCountableConnectionFields(GraphQLField): @classmethod def edges(cls) -> "ProductCountableEdgeFields": - return ProductCountableEdgeFields("edges", arguments={}) + return ProductCountableEdgeFields("edges") @classmethod def page_info(cls) -> "PageInfoFields": - return PageInfoFields("page_info", arguments={}) + return PageInfoFields("page_info") - total_count: ProductCountableConnectionGraphQLField = ( + total_count: "ProductCountableConnectionGraphQLField" = ( ProductCountableConnectionGraphQLField("totalCount") ) @@ -172,6 +197,7 @@ def fields( "ProductCountableEdgeFields", ] ) -> "ProductCountableConnectionFields": + """Subfields should come from the ProductCountableConnectionFields class""" self._subfields.extend(subfields) return self @@ -179,42 +205,44 @@ def fields( class ProductCountableEdgeFields(GraphQLField): @classmethod def node(cls) -> "ProductFields": - return ProductFields("node", arguments={}) + return ProductFields("node") - cursor: ProductCountableEdgeGraphQLField = ProductCountableEdgeGraphQLField( + cursor: "ProductCountableEdgeGraphQLField" = ProductCountableEdgeGraphQLField( "cursor" ) def fields( self, *subfields: Union[ProductCountableEdgeGraphQLField, "ProductFields"] ) -> "ProductCountableEdgeFields": + """Subfields should come from the ProductCountableEdgeFields class""" self._subfields.extend(subfields) return self class ProductTranslatableContentFields(GraphQLField): - id: ProductTranslatableContentGraphQLField = ProductTranslatableContentGraphQLField( - "id" + id: "ProductTranslatableContentGraphQLField" = ( + ProductTranslatableContentGraphQLField("id") ) - product_id: ProductTranslatableContentGraphQLField = ( + product_id: "ProductTranslatableContentGraphQLField" = ( ProductTranslatableContentGraphQLField("productId") ) - seo_title: ProductTranslatableContentGraphQLField = ( + seo_title: "ProductTranslatableContentGraphQLField" = ( ProductTranslatableContentGraphQLField("seoTitle") ) - seo_description: ProductTranslatableContentGraphQLField = ( + seo_description: "ProductTranslatableContentGraphQLField" = ( ProductTranslatableContentGraphQLField("seoDescription") ) - name: ProductTranslatableContentGraphQLField = ( + name: "ProductTranslatableContentGraphQLField" = ( ProductTranslatableContentGraphQLField("name") ) - description: ProductTranslatableContentGraphQLField = ( + description: "ProductTranslatableContentGraphQLField" = ( ProductTranslatableContentGraphQLField("description") ) def fields( self, *subfields: ProductTranslatableContentGraphQLField ) -> "ProductTranslatableContentFields": + """Subfields should come from the ProductTranslatableContentFields class""" self._subfields.extend(subfields) return self @@ -222,12 +250,13 @@ def fields( class ProductTypeCountableConnectionFields(GraphQLField): @classmethod def page_info(cls) -> "PageInfoFields": - return PageInfoFields("page_info", arguments={}) + return PageInfoFields("page_info") def fields( self, *subfields: Union[ProductTypeCountableConnectionGraphQLField, "PageInfoFields"] ) -> "ProductTypeCountableConnectionFields": + """Subfields should come from the ProductTypeCountableConnectionFields class""" self._subfields.extend(subfields) return self @@ -235,13 +264,13 @@ def fields( class TranslatableItemConnectionFields(GraphQLField): @classmethod def page_info(cls) -> "PageInfoFields": - return PageInfoFields("page_info", arguments={}) + return PageInfoFields("page_info") @classmethod def edges(cls) -> "TranslatableItemEdgeFields": - return TranslatableItemEdgeFields("edges", arguments={}) + return TranslatableItemEdgeFields("edges") - total_count: TranslatableItemConnectionGraphQLField = ( + total_count: "TranslatableItemConnectionGraphQLField" = ( TranslatableItemConnectionGraphQLField("totalCount") ) @@ -253,13 +282,14 @@ def fields( "TranslatableItemEdgeFields", ] ) -> "TranslatableItemConnectionFields": + """Subfields should come from the TranslatableItemConnectionFields class""" self._subfields.extend(subfields) return self class TranslatableItemEdgeFields(GraphQLField): - node: TranslatableItemUnion = TranslatableItemUnion("node") - cursor: TranslatableItemEdgeGraphQLField = TranslatableItemEdgeGraphQLField( + node: "TranslatableItemUnion" = TranslatableItemUnion("node") + cursor: "TranslatableItemEdgeGraphQLField" = TranslatableItemEdgeGraphQLField( "cursor" ) @@ -267,6 +297,7 @@ def fields( self, *subfields: Union[TranslatableItemEdgeGraphQLField, "TranslatableItemUnion"] ) -> "TranslatableItemEdgeFields": + """Subfields should come from the TranslatableItemEdgeFields class""" self._subfields.extend(subfields) return self @@ -274,15 +305,15 @@ def fields( class UpdateMetadataFields(GraphQLField): @classmethod def metadata_errors(cls) -> "MetadataErrorFields": - return MetadataErrorFields("metadata_errors", arguments={}) + return MetadataErrorFields("metadata_errors") @classmethod def errors(cls) -> "MetadataErrorFields": - return MetadataErrorFields("errors", arguments={}) + return MetadataErrorFields("errors") @classmethod def item(cls) -> "ObjectWithMetadataInterface": - return ObjectWithMetadataInterface("item", arguments={}) + return ObjectWithMetadataInterface("item") def fields( self, @@ -292,5 +323,6 @@ def fields( "ObjectWithMetadataInterface", ] ) -> "UpdateMetadataFields": + """Subfields should come from the UpdateMetadataFields class""" self._subfields.extend(subfields) return self diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py index 90bfaf4f..d14cf4db 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py @@ -26,21 +26,11 @@ def products( @classmethod def app(cls) -> AppFields: - arguments: Dict[str, Dict[str, Any]] = {} - cleared_arguments = { - key: value for key, value in arguments.items() if value["value"] is not None - } - return AppFields(field_name="app", arguments=cleared_arguments) + return AppFields(field_name="app") @classmethod def product_types(cls) -> ProductTypeCountableConnectionFields: - arguments: Dict[str, Dict[str, Any]] = {} - cleared_arguments = { - key: value for key, value in arguments.items() if value["value"] is not None - } - return ProductTypeCountableConnectionFields( - field_name="productTypes", arguments=cleared_arguments - ) + return ProductTypeCountableConnectionFields(field_name="productTypes") @classmethod def translations( diff --git a/tests/main/custom_operation_builder/graphql_client/__init__.py b/tests/main/custom_operation_builder/graphql_client/__init__.py index 41117541..de6c91ca 100644 --- a/tests/main/custom_operation_builder/graphql_client/__init__.py +++ b/tests/main/custom_operation_builder/graphql_client/__init__.py @@ -1,14 +1,6 @@ from .async_base_client import AsyncBaseClient from .base_model import BaseModel, Upload from .client import Client -from .custom_typing_fields import ( - AdminGraphQLField, - GuestGraphQLField, - PersonInterfaceGraphQLField, - PostGraphQLField, - SearchResultUnion, - UserGraphQLField, -) from .enums import Role from .exceptions import ( GraphQLClientError, @@ -21,7 +13,6 @@ __all__ = [ "AddUserInput", - "AdminGraphQLField", "AsyncBaseClient", "BaseModel", "Client", @@ -30,12 +21,7 @@ "GraphQLClientGraphQLMultiError", "GraphQLClientHttpError", "GraphQLClientInvalidResponseError", - "GuestGraphQLField", - "PersonInterfaceGraphQLField", - "PostGraphQLField", "Role", - "SearchResultUnion", "UpdateUserInput", "Upload", - "UserGraphQLField", ] diff --git a/tests/main/custom_operation_builder/graphql_client/base_operation.py b/tests/main/custom_operation_builder/graphql_client/base_operation.py index 0695b558..9f9c7660 100644 --- a/tests/main/custom_operation_builder/graphql_client/base_operation.py +++ b/tests/main/custom_operation_builder/graphql_client/base_operation.py @@ -12,11 +12,16 @@ class GraphQLArgument: - def __init__(self, argument_name: str, argument_value: Any): + """ + Represents a GraphQL argument and allows conversion to an AST structure. + """ + + def __init__(self, argument_name: str, argument_value: Any) -> None: self._name = argument_name self._value = argument_value def to_ast(self) -> ArgumentNode: + """Converts the argument to an ArgumentNode AST object.""" return ArgumentNode( name=NameNode(value=self._name), value=VariableNode(name=NameNode(value=self._value)), @@ -24,6 +29,15 @@ def to_ast(self) -> ArgumentNode: class GraphQLField: + """ + Represents a GraphQL field with its name, arguments, subfields, alias, + and inline fragments. + + Attributes: + formatted_variables (Dict[str, Dict[str, Any]]): The formatted arguments + of the GraphQL field. + """ + def __init__( self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None ) -> None: @@ -35,24 +49,25 @@ def __init__( self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} def alias(self, alias: str) -> "GraphQLField": + """Sets an alias for the GraphQL field and returns the instance.""" self._alias = alias return self - def add_subfield(self, subfield: "GraphQLField") -> None: - self._subfields.append(subfield) - - def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> None: - self._inline_fragments[type_name] = subfields - def _build_field_name(self) -> str: + """Builds the field name, including the alias if present.""" return f"{self._alias}: {self._field_name}" if self._alias else self._field_name def _build_selections( self, idx: int, used_names: Set[str] ) -> List[Union[FieldNode, InlineFragmentNode]]: + """Builds the selection set for the current GraphQL field, + including subfields and inline fragments.""" + # Create selections from subfields selections: List[Union[FieldNode, InlineFragmentNode]] = [ subfield.to_ast(idx, used_names) for subfield in self._subfields ] + + # Add inline fragments for name, subfields in self._inline_fragments.items(): selections.append( InlineFragmentNode( @@ -64,22 +79,35 @@ def _build_selections( ), ) ) + return selections def _format_variable_name( self, idx: int, var_name: str, used_names: Set[str] ) -> str: + """Generates a unique variable name by appending an index and, + if necessary, an additional counter to avoid duplicates.""" base_name = f"{var_name}_{idx}" unique_name = base_name counter = 1 + + # Ensure the generated name is unique while unique_name in used_names: unique_name = f"{base_name}_{counter}" counter += 1 + + # Add the unique name to the set of used names used_names.add(unique_name) + return unique_name def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + """ + Collects and formats all variables for the current GraphQL field, + ensuring unique names. + """ self.formatted_variables = {} + for k, v in self._variables.items(): unique_name = self._format_variable_name(idx, k, used_names) self.formatted_variables[unique_name] = { @@ -89,16 +117,18 @@ def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: } def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + """Converts the current GraphQL field to an AST (Abstract Syntax Tree) node.""" if used_names is None: used_names = set() + self._collect_all_variables(idx, used_names) - formatted_args = [ - GraphQLArgument(v["name"], k).to_ast() - for k, v in self.formatted_variables.items() - ] + return FieldNode( name=NameNode(value=self._build_field_name()), - arguments=formatted_args, + arguments=[ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ], selection_set=( SelectionSetNode(selections=self._build_selections(idx, used_names)) if self._subfields or self._inline_fragments @@ -107,12 +137,20 @@ def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: ) def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: - formatted_variables = self.formatted_variables + """ + Retrieves all formatted variables for the current GraphQL field, + including those from subfields and inline fragments. + """ + formatted_variables = self.formatted_variables.copy() + + # Collect variables from subfields for subfield in self._subfields: subfield.get_formatted_variables() - self.formatted_variables.update(subfield.formatted_variables) + formatted_variables.update(subfield.formatted_variables) + + # Collect variables from inline fragments for subfields in self._inline_fragments.values(): for subfield in subfields: subfield.get_formatted_variables() - self.formatted_variables.update(subfield.formatted_variables) + formatted_variables.update(subfield.formatted_variables) return formatted_variables diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py index 2c0856fe..345273c9 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -1,65 +1,88 @@ -from typing import Union +from typing import Any, Dict, Optional, Union -from . import ( +from .base_operation import GraphQLField +from .custom_typing_fields import ( AdminGraphQLField, GuestGraphQLField, PersonInterfaceGraphQLField, PostGraphQLField, UserGraphQLField, ) -from .base_operation import GraphQLField class AdminFields(GraphQLField): - id: AdminGraphQLField = AdminGraphQLField("id") - name: AdminGraphQLField = AdminGraphQLField("name") - privileges: AdminGraphQLField = AdminGraphQLField("privileges") - email: AdminGraphQLField = AdminGraphQLField("email") - created_at: AdminGraphQLField = AdminGraphQLField("createdAt") + id: "AdminGraphQLField" = AdminGraphQLField("id") + name: "AdminGraphQLField" = AdminGraphQLField("name") + privileges: "AdminGraphQLField" = AdminGraphQLField("privileges") + email: "AdminGraphQLField" = AdminGraphQLField("email") + created_at: "AdminGraphQLField" = AdminGraphQLField("createdAt") @classmethod def metafield(cls, key: str) -> "AdminGraphQLField": - return AdminGraphQLField( - "metafield", arguments={"key": {"type": "String!", "value": key}} - ) + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return AdminGraphQLField("metafield", arguments=cleared_arguments) + + @classmethod + def custom_field(cls, *, key: Optional[str] = None) -> "AdminGraphQLField": + arguments: Dict[str, Dict[str, Any]] = {"key": {"type": "String", "value": key}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return AdminGraphQLField("custom_field", arguments=cleared_arguments) def fields(self, *subfields: AdminGraphQLField) -> "AdminFields": + """Subfields should come from the AdminFields class""" self._subfields.extend(subfields) return self class GuestFields(GraphQLField): - id: GuestGraphQLField = GuestGraphQLField("id") - name: GuestGraphQLField = GuestGraphQLField("name") - visit_count: GuestGraphQLField = GuestGraphQLField("visitCount") - email: GuestGraphQLField = GuestGraphQLField("email") - created_at: GuestGraphQLField = GuestGraphQLField("createdAt") + id: "GuestGraphQLField" = GuestGraphQLField("id") + name: "GuestGraphQLField" = GuestGraphQLField("name") + visit_count: "GuestGraphQLField" = GuestGraphQLField("visitCount") + email: "GuestGraphQLField" = GuestGraphQLField("email") + created_at: "GuestGraphQLField" = GuestGraphQLField("createdAt") @classmethod def metafield(cls, key: str) -> "GuestGraphQLField": - return GuestGraphQLField( - "metafield", arguments={"key": {"type": "String!", "value": key}} - ) + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return GuestGraphQLField("metafield", arguments=cleared_arguments) def fields(self, *subfields: GuestGraphQLField) -> "GuestFields": + """Subfields should come from the GuestFields class""" self._subfields.extend(subfields) return self class PersonInterfaceInterface(GraphQLField): - id: PersonInterfaceGraphQLField = PersonInterfaceGraphQLField("id") - name: PersonInterfaceGraphQLField = PersonInterfaceGraphQLField("name") - email: PersonInterfaceGraphQLField = PersonInterfaceGraphQLField("email") + id: "PersonInterfaceGraphQLField" = PersonInterfaceGraphQLField("id") + name: "PersonInterfaceGraphQLField" = PersonInterfaceGraphQLField("name") + email: "PersonInterfaceGraphQLField" = PersonInterfaceGraphQLField("email") @classmethod def metafield(cls, key: str) -> "PersonInterfaceGraphQLField": - return PersonInterfaceGraphQLField( - "metafield", arguments={"key": {"type": "String!", "value": key}} - ) + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PersonInterfaceGraphQLField("metafield", arguments=cleared_arguments) def fields( self, *subfields: PersonInterfaceGraphQLField ) -> "PersonInterfaceInterface": + """Subfields should come from the PersonInterfaceInterface class""" self._subfields.extend(subfields) return self @@ -71,41 +94,47 @@ def on( class PostFields(GraphQLField): - id: PostGraphQLField = PostGraphQLField("id") - title: PostGraphQLField = PostGraphQLField("title") - content: PostGraphQLField = PostGraphQLField("content") + id: "PostGraphQLField" = PostGraphQLField("id") + title: "PostGraphQLField" = PostGraphQLField("title") + content: "PostGraphQLField" = PostGraphQLField("content") @classmethod def author(cls) -> "PersonInterfaceInterface": - return PersonInterfaceInterface("author", arguments={}) + return PersonInterfaceInterface("author") - published_at: PostGraphQLField = PostGraphQLField("publishedAt") + published_at: "PostGraphQLField" = PostGraphQLField("publishedAt") def fields( self, *subfields: Union[PostGraphQLField, "PersonInterfaceInterface"] ) -> "PostFields": + """Subfields should come from the PostFields class""" self._subfields.extend(subfields) return self class UserFields(GraphQLField): - id: UserGraphQLField = UserGraphQLField("id") - name: UserGraphQLField = UserGraphQLField("name") - age: UserGraphQLField = UserGraphQLField("age") - email: UserGraphQLField = UserGraphQLField("email") - role: UserGraphQLField = UserGraphQLField("role") - created_at: UserGraphQLField = UserGraphQLField("createdAt") + id: "UserGraphQLField" = UserGraphQLField("id") + name: "UserGraphQLField" = UserGraphQLField("name") + age: "UserGraphQLField" = UserGraphQLField("age") + email: "UserGraphQLField" = UserGraphQLField("email") + role: "UserGraphQLField" = UserGraphQLField("role") + created_at: "UserGraphQLField" = UserGraphQLField("createdAt") @classmethod def friends(cls) -> "UserFields": - return UserFields("friends", arguments={}) + return UserFields("friends") @classmethod def metafield(cls, key: str) -> "UserGraphQLField": - return UserGraphQLField( - "metafield", arguments={"key": {"type": "String!", "value": key}} - ) + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserGraphQLField("metafield", arguments=cleared_arguments) def fields(self, *subfields: Union[UserGraphQLField, "UserFields"]) -> "UserFields": + """Subfields should come from the UserFields class""" self._subfields.extend(subfields) return self diff --git a/tests/main/custom_operation_builder/graphql_client/custom_queries.py b/tests/main/custom_operation_builder/graphql_client/custom_queries.py index 71ee4d8a..353ed882 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_queries.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_queries.py @@ -7,11 +7,7 @@ class Query: @classmethod def hello(cls) -> GraphQLField: - arguments: Dict[str, Dict[str, Any]] = {} - cleared_arguments = { - key: value for key, value in arguments.items() if value["value"] is not None - } - return GraphQLField(field_name="hello", arguments=cleared_arguments) + return GraphQLField(field_name="hello") @classmethod def greeting(cls, *, name: Optional[str] = None) -> GraphQLField: @@ -35,11 +31,7 @@ def user(cls, user_id: str) -> UserFields: @classmethod def users(cls) -> UserFields: - arguments: Dict[str, Dict[str, Any]] = {} - cleared_arguments = { - key: value for key, value in arguments.items() if value["value"] is not None - } - return UserFields(field_name="users", arguments=cleared_arguments) + return UserFields(field_name="users") @classmethod def search(cls, text: str) -> SearchResultUnion: @@ -53,11 +45,7 @@ def search(cls, text: str) -> SearchResultUnion: @classmethod def posts(cls) -> PostFields: - arguments: Dict[str, Dict[str, Any]] = {} - cleared_arguments = { - key: value for key, value in arguments.items() if value["value"] is not None - } - return PostFields(field_name="posts", arguments=cleared_arguments) + return PostFields(field_name="posts") @classmethod def person(cls, person_id: str) -> PersonInterfaceInterface: @@ -73,10 +61,4 @@ def person(cls, person_id: str) -> PersonInterfaceInterface: @classmethod def people(cls) -> PersonInterfaceInterface: - arguments: Dict[str, Dict[str, Any]] = {} - cleared_arguments = { - key: value for key, value in arguments.items() if value["value"] is not None - } - return PersonInterfaceInterface( - field_name="people", arguments=cleared_arguments - ) + return PersonInterfaceInterface(field_name="people") diff --git a/tests/main/custom_operation_builder/schema.graphql b/tests/main/custom_operation_builder/schema.graphql index 34206824..74be4fd9 100644 --- a/tests/main/custom_operation_builder/schema.graphql +++ b/tests/main/custom_operation_builder/schema.graphql @@ -34,6 +34,7 @@ type User implements PersonInterface { role: Role! createdAt: String friends: [User] + metafield(key: String!): String } type Admin implements PersonInterface { @@ -42,6 +43,8 @@ type Admin implements PersonInterface { privileges: [String!]! email: String! createdAt: String + metafield(key: String!): String + customField(key: String): String } type Guest implements PersonInterface { @@ -50,6 +53,7 @@ type Guest implements PersonInterface { visitCount: Int email: String! createdAt: String + metafield(key: String!): String } type Post { diff --git a/tests/main/custom_operation_builder/test_operation_build.py b/tests/main/custom_operation_builder/test_operation_build.py index 76d73740..cdd39657 100644 --- a/tests/main/custom_operation_builder/test_operation_build.py +++ b/tests/main/custom_operation_builder/test_operation_build.py @@ -20,56 +20,60 @@ def test_simple_hello(): def test_greeting_with_name(): - built_query = print_ast(Query.greeting(name="Alice").to_ast(0)) + query = Query.greeting(name="Alice") expected_query = "greeting(name: $name_0)" + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "name_0": {"name": "name", "type": "String", "value": "Alice"} + } def test_user_by_id(): - built_query = print_ast( - Query.user(user_id="1") - .fields( - UserFields.id, - UserFields.name, - UserFields.age, - UserFields.email, - ) - .to_ast(0) + query = Query.user(user_id="1").fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, ) expected_query = "user(user_id: $user_id_0) {\n id\n name\n age\n email\n}" + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } def test_all_users(): - built_query = print_ast( - Query.users() - .fields( - UserFields.id, - UserFields.name, - UserFields.age, - UserFields.email, - ) - .to_ast(0) + query = Query.users().fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, ) expected_query = "users {\n id\n name\n age\n email\n}" + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert not query.get_formatted_variables() def test_user_with_friends(): - built_query = print_ast( - Query.user(user_id="1") - .fields( + query = Query.user(user_id="1").fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.friends().fields( UserFields.id, UserFields.name, - UserFields.age, - UserFields.email, - UserFields.friends().fields( - UserFields.id, - UserFields.name, - ), - UserFields.created_at, - ) - .to_ast(0) + ), + UserFields.created_at, ) expected_query = ( "user(user_id: $user_id_0) {\n" @@ -84,11 +88,17 @@ def test_user_with_friends(): " createdAt\n" "}" ) + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } def test_search_example(): - built_query = print_ast( + query = ( Query.search(text="example") .on( "User", @@ -111,7 +121,6 @@ def test_search_example(): GuestFields.visit_count, GuestFields.created_at, ) - .to_ast(0) ) expected_query = ( "search(text: $text_0) {\n" @@ -135,24 +144,26 @@ def test_search_example(): " }\n" "}" ) + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "text_0": {"name": "text", "type": "String!", "value": "example"} + } def test_posts_with_authors(): - built_query = print_ast( - Query.posts() - .fields( - PostFields.id, - PostFields.title, - PostFields.content, - PostFields.author().fields( - PersonInterfaceInterface.id, - PersonInterfaceInterface.name, - PersonInterfaceInterface.email, - ), - PostFields.published_at, - ) - .to_ast(0) + query = Query.posts().fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.author().fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + ), + PostFields.published_at, ) expected_query = ( "posts {\n" @@ -167,20 +178,33 @@ def test_posts_with_authors(): " publishedAt\n" "}" ) + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert not query.get_formatted_variables() def test_get_person(): - built_query = print_ast( + query = ( Query.person(person_id="1") .fields( PersonInterfaceInterface.id, PersonInterfaceInterface.name, PersonInterfaceInterface.email, ) - .on("User", UserFields.age, UserFields.role) - .on("Admin", AdminFields.privileges) - .to_ast(0) + .on( + "User", + UserFields.age, + UserFields.role, + UserFields.metafield(key="meta"), + ) + .on( + "Admin", + AdminFields.privileges, + AdminFields.custom_field(), + AdminFields.metafield(key="meta"), + ) ) expected_query = ( "person(person_id: $person_id_0) {\n" @@ -190,17 +214,28 @@ def test_get_person(): " ... on User {\n" " age\n" " role\n" + " metafield(key: $key_0)\n" " }\n" " ... on Admin {\n" " privileges\n" + " custom_field\n" + " metafield(key: $key_0_1)\n" " }\n" "}" ) + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "person_id_0": {"name": "person_id", "type": "ID!", "value": "1"}, + "key_0": {"name": "key", "type": "String!", "value": "meta"}, + "key_0_1": {"name": "key", "type": "String!", "value": "meta"}, + } def test_get_people(): - built_query = print_ast( + query = ( Query.people() .fields( PersonInterfaceInterface.id, @@ -209,7 +244,6 @@ def test_get_people(): ) .on("User", UserFields.age, UserFields.role) .on("Admin", AdminFields.privileges) - .to_ast(0) ) expected_query = ( "people {\n" @@ -225,19 +259,22 @@ def test_get_people(): " }\n" "}" ) + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert not query.get_formatted_variables() def test_add_user_mutation(): - mutation = Mutation.add_user( - user_input=AddUserInput( - name="bob", - age=30, - email="bob@example.com", - role=Role.ADMIN, - createdAt="2024-06-07T00:00:00.000Z", - ) - ).fields( + user_input = AddUserInput( + name="bob", + age=30, + email="bob@example.com", + role=Role.ADMIN, + createdAt="2024-06-07T00:00:00.000Z", + ) + mutation = Mutation.add_user(user_input=user_input).fields( UserFields.id, UserFields.name, UserFields.age, @@ -245,7 +282,6 @@ def test_add_user_mutation(): UserFields.role, UserFields.created_at, ) - built_mutation = print_ast(mutation.to_ast(0)) expected_mutation = ( "addUser(" "user_input: $user_input_0" @@ -258,43 +294,37 @@ def test_add_user_mutation(): " createdAt\n" "}" ) + + built_mutation = print_ast(mutation.to_ast(0)) + assert built_mutation == expected_mutation assert mutation.get_formatted_variables() == { "user_input_0": { "name": "user_input", "type": "AddUserInput!", - "value": AddUserInput( - name="bob", - age=30, - email="bob@example.com", - role=Role.ADMIN, - created_at="2024-06-07T00:00:00.000Z", - ), + "value": user_input, } } def test_update_user_mutation(): - built_mutation = print_ast( - Mutation.update_user( - user_id="1", - user_input=UpdateUserInput( - name="Alice Updated", - age=25, - email="alice.updated@example.com", - role=Role.USER, - createdAt="2024-06-07T00:00:00.000Z", - ), - ) - .fields( - UserFields.id, - UserFields.name, - UserFields.age, - UserFields.email, - UserFields.role, - UserFields.created_at, - ) - .to_ast(0) + user_input = UpdateUserInput( + name="Alice Updated", + age=25, + email="alice.updated@example.com", + role=Role.USER, + createdAt="2024-06-07T00:00:00.000Z", + ) + mutation = Mutation.update_user( + user_id="1", + user_input=user_input, + ).fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.role, + UserFields.created_at, ) expected_mutation = ( "updateUser(" @@ -309,41 +339,50 @@ def test_update_user_mutation(): " createdAt\n" "}" ) + + built_mutation = print_ast(mutation.to_ast(0)) + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"}, + "user_input_0": { + "name": "user_input", + "type": "UpdateUserInput!", + "value": user_input, + }, + } def test_delete_user_mutation(): - built_mutation = print_ast( - Mutation.delete_user(user_id="1") - .fields( - UserFields.id, - UserFields.name, - ) - .to_ast(0) + mutation = Mutation.delete_user(user_id="1").fields( + UserFields.id, + UserFields.name, ) expected_mutation = "deleteUser(user_id: $user_id_0) {\n id\n name\n}" + + built_mutation = print_ast(mutation.to_ast(0)) + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } def test_add_post_mutation(): - built_mutation = print_ast( - Mutation.add_post( - title="New Post", - content="This is the content", - author_id="1", - published_at="2024-06-07T00:00:00.000Z", - ) - .fields( - PostFields.id, - PostFields.title, - PostFields.content, - PostFields.author().fields( - PersonInterfaceInterface.id, - PersonInterfaceInterface.name, - ), - PostFields.published_at, - ) - .to_ast(0) + mutation = Mutation.add_post( + title="New Post", + content="This is the content", + author_id="1", + published_at="2024-06-07T00:00:00.000Z", + ).fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.author().fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + ), + PostFields.published_at, ) expected_mutation = ( "addPost(\n" @@ -362,24 +401,45 @@ def test_add_post_mutation(): " publishedAt\n" "}" ) + + built_mutation = print_ast(mutation.to_ast(0)) + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "title_0": { + "name": "title", + "type": "String!", + "value": "New Post", + }, + "content_0": { + "name": "content", + "type": "String!", + "value": "This is the content", + }, + "authorId_0": { + "name": "authorId", + "type": "ID!", + "value": "1", + }, + "publishedAt_0": { + "name": "publishedAt", + "type": "String!", + "value": "2024-06-07T00:00:00.000Z", + }, + } def test_update_post_mutation(): - built_mutation = print_ast( - Mutation.update_post( - post_id="1", - title="Updated Title", - content="Updated Content", - published_at="2024-06-07T00:00:00.000Z", - ) - .fields( - PostFields.id, - PostFields.title, - PostFields.content, - PostFields.published_at, - ) - .to_ast(0) + mutation = Mutation.update_post( + post_id="1", + title="Updated Title", + content="Updated Content", + published_at="2024-06-07T00:00:00.000Z", + ).fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.published_at, ) expected_mutation = ( "updatePost(\n" @@ -394,40 +454,55 @@ def test_update_post_mutation(): " publishedAt\n" "}" ) + + built_mutation = print_ast(mutation.to_ast(0)) + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "post_id_0": {"name": "post_id", "type": "ID!", "value": "1"}, + "title_0": {"name": "title", "type": "String", "value": "Updated Title"}, + "content_0": {"name": "content", "type": "String", "value": "Updated Content"}, + "publishedAt_0": { + "name": "publishedAt", + "type": "String", + "value": "2024-06-07T00:00:00.000Z", + }, + } def test_delete_post_mutation(): - built_mutation = print_ast( - Mutation.delete_post(post_id="1") - .fields( - PostFields.id, - PostFields.title, - ) - .to_ast(0) + mutation = Mutation.delete_post(post_id="1").fields( + PostFields.id, + PostFields.title, ) expected_mutation = "deletePost(post_id: $post_id_0) {\n id\n title\n}" + + built_mutation = print_ast(mutation.to_ast(0)) + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "post_id_0": {"name": "post_id", "type": "ID!", "value": "1"} + } def test_user_specific_fields(): - built_query = print_ast( - Query.user(user_id="1").fields(UserFields.id, UserFields.name).to_ast(0) - ) + query = Query.user(user_id="1").fields(UserFields.id, UserFields.name) expected_query = "user(user_id: $user_id_0) {\n id\n name\n}" + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } def test_user_with_friends_specific_fields(): - built_query = print_ast( - Query.user(user_id="1") - .fields( - UserFields.id, - UserFields.name, - UserFields.friends().fields(UserFields.id, UserFields.name), - UserFields.created_at, - ) - .to_ast(0) + query = Query.user(user_id="1").fields( + UserFields.id, + UserFields.name, + UserFields.friends().fields(UserFields.id, UserFields.name), + UserFields.created_at, ) expected_query = ( "user(user_id: $user_id_0) {\n" @@ -440,7 +515,13 @@ def test_user_with_friends_specific_fields(): " createdAt\n" "}" ) + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } def test_people_with_metadata(): @@ -455,7 +536,6 @@ def test_people_with_metadata(): ) .on("User", UserFields.age, UserFields.role) ) - built_query = print_ast(query.to_ast(0)) expected_query = ( "people {\n" " id\n" @@ -469,4 +549,11 @@ def test_people_with_metadata(): " }\n" "}" ) + + built_query = print_ast(query.to_ast(0)) + assert built_query == expected_query + assert query.get_formatted_variables() == { + "key_0": {"name": "key", "type": "String!", "value": "bio"}, + "key_0_1": {"name": "key", "type": "String!", "value": "ots"}, + } From 99e0874f7421f9fcf3e6869c806e09ecbeb2b5a9 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Wed, 17 Jul 2024 14:48:10 +0200 Subject: [PATCH 10/11] Fix for aliases --- .../client_generators/custom_fields.py | 26 ++++++++ .../client_generators/custom_fields_typing.py | 27 ++++++++- .../expected_client/custom_fields.py | 56 +++++++++++++++++ .../expected_client/custom_typing_fields.py | 60 ++++++++++++++----- .../graphql_client/custom_fields.py | 20 +++++++ .../graphql_client/custom_typing_fields.py | 24 ++++++-- .../test_operation_build.py | 6 ++ 7 files changed, 199 insertions(+), 20 deletions(-) diff --git a/ariadne_codegen/client_generators/custom_fields.py b/ariadne_codegen/client_generators/custom_fields.py index 78cbad51..0c0ba369 100644 --- a/ariadne_codegen/client_generators/custom_fields.py +++ b/ariadne_codegen/client_generators/custom_fields.py @@ -160,6 +160,7 @@ def _generate_class_def_body( class_name, definition.name, sorted(additional_fields_typing) ) ) + class_def.body.append(self._generate_alias_method(class_name)) return class_def def _get_combined_fields( @@ -361,3 +362,28 @@ def _get_suffix( if isinstance(graphql_type, GraphQLInterfaceType): return GRAPHQL_INTERFACE_SUFFIX raise ValueError(f"Unexpected graphql_type: {graphql_type}") + + def _generate_alias_method(self, class_name: str) -> ast.FunctionDef: + """ + Generates the `alias` method for a class. + """ + return generate_method_definition( + "alias", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="alias", annotation=generate_name("str")), + ] + ), + body=[ + ast.Assign( + targets=[ + generate_attribute(value=generate_name("self"), attr="_alias"), + ], + value=generate_name("alias"), + lineno=1, + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) diff --git a/ariadne_codegen/client_generators/custom_fields_typing.py b/ariadne_codegen/client_generators/custom_fields_typing.py index 8c9fd6d3..54adb933 100644 --- a/ariadne_codegen/client_generators/custom_fields_typing.py +++ b/ariadne_codegen/client_generators/custom_fields_typing.py @@ -78,7 +78,7 @@ def _generate_field_class( if isinstance(graphql_type, GraphQLUnionType): class_name = f"{graphql_type.name}Union" class_body.append(self._generate_on_method(class_name)) - + class_body.append(self._generate_alias_method(class_name)) if class_name not in self._public_names: self._public_names.append(class_name) @@ -123,6 +123,31 @@ def _generate_on_method(self, class_name: str) -> ast.FunctionDef: return_type=generate_name(f'"{class_name}"'), ) + def _generate_alias_method(self, class_name: str) -> ast.FunctionDef: + """ + Generates the `alias` method for a class. + """ + return generate_method_definition( + "alias", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="alias", annotation=generate_name("str")), + ] + ), + body=[ + ast.Assign( + targets=[ + generate_attribute(value=generate_name("self"), attr="_alias"), + ], + value=generate_name("alias"), + lineno=1, + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) + def get_generated_public_names(self) -> List[str]: """ Returns the list of generated public names. diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py index 60ab1e41..799b8391 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py @@ -28,6 +28,10 @@ def fields(self, *subfields: AppGraphQLField) -> "AppFields": self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "AppFields": + self._alias = alias + return self + class CollectionTranslatableContentFields(GraphQLField): id: "CollectionTranslatableContentGraphQLField" = ( @@ -56,6 +60,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "CollectionTranslatableContentFields": + self._alias = alias + return self + class MetadataErrorFields(GraphQLField): field: "MetadataErrorGraphQLField" = MetadataErrorGraphQLField("field") @@ -67,6 +75,10 @@ def fields(self, *subfields: MetadataErrorGraphQLField) -> "MetadataErrorFields" self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "MetadataErrorFields": + self._alias = alias + return self + class MetadataItemFields(GraphQLField): key: "MetadataItemGraphQLField" = MetadataItemGraphQLField("key") @@ -77,6 +89,10 @@ def fields(self, *subfields: MetadataItemGraphQLField) -> "MetadataItemFields": self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "MetadataItemFields": + self._alias = alias + return self + class ObjectWithMetadataInterface(GraphQLField): @classmethod @@ -116,6 +132,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "ObjectWithMetadataInterface": + self._alias = alias + return self + def on( self, type_name: str, *subfields: GraphQLField ) -> "ObjectWithMetadataInterface": @@ -134,6 +154,10 @@ def fields(self, *subfields: PageInfoGraphQLField) -> "PageInfoFields": self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "PageInfoFields": + self._alias = alias + return self + class ProductFields(GraphQLField): id: "ProductGraphQLField" = ProductGraphQLField("id") @@ -175,6 +199,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "ProductFields": + self._alias = alias + return self + class ProductCountableConnectionFields(GraphQLField): @classmethod @@ -201,6 +229,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "ProductCountableConnectionFields": + self._alias = alias + return self + class ProductCountableEdgeFields(GraphQLField): @classmethod @@ -218,6 +250,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "ProductCountableEdgeFields": + self._alias = alias + return self + class ProductTranslatableContentFields(GraphQLField): id: "ProductTranslatableContentGraphQLField" = ( @@ -246,6 +282,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "ProductTranslatableContentFields": + self._alias = alias + return self + class ProductTypeCountableConnectionFields(GraphQLField): @classmethod @@ -260,6 +300,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "ProductTypeCountableConnectionFields": + self._alias = alias + return self + class TranslatableItemConnectionFields(GraphQLField): @classmethod @@ -286,6 +330,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "TranslatableItemConnectionFields": + self._alias = alias + return self + class TranslatableItemEdgeFields(GraphQLField): node: "TranslatableItemUnion" = TranslatableItemUnion("node") @@ -301,6 +349,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "TranslatableItemEdgeFields": + self._alias = alias + return self + class UpdateMetadataFields(GraphQLField): @classmethod @@ -326,3 +378,7 @@ def fields( """Subfields should come from the UpdateMetadataFields class""" self._subfields.extend(subfields) return self + + def alias(self, alias: str) -> "UpdateMetadataFields": + self._alias = alias + return self diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py index 8d8f7d7d..91be62aa 100644 --- a/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py +++ b/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py @@ -2,51 +2,75 @@ class ProductGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "ProductGraphQLField": + self._alias = alias + return self class ProductCountableEdgeGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "ProductCountableEdgeGraphQLField": + self._alias = alias + return self class ProductCountableConnectionGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "ProductCountableConnectionGraphQLField": + self._alias = alias + return self class AppGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "AppGraphQLField": + self._alias = alias + return self class ProductTypeCountableConnectionGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "ProductTypeCountableConnectionGraphQLField": + self._alias = alias + return self class PageInfoGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "PageInfoGraphQLField": + self._alias = alias + return self class ObjectWithMetadataGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "ObjectWithMetadataGraphQLField": + self._alias = alias + return self class MetadataItemGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "MetadataItemGraphQLField": + self._alias = alias + return self class UpdateMetadataGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "UpdateMetadataGraphQLField": + self._alias = alias + return self class MetadataErrorGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "MetadataErrorGraphQLField": + self._alias = alias + return self class TranslatableItemConnectionGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "TranslatableItemConnectionGraphQLField": + self._alias = alias + return self class TranslatableItemEdgeGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "TranslatableItemEdgeGraphQLField": + self._alias = alias + return self class TranslatableItemUnion(GraphQLField): @@ -54,10 +78,18 @@ def on(self, type_name: str, *subfields: GraphQLField) -> "TranslatableItemUnion self._inline_fragments[type_name] = subfields return self + def alias(self, alias: str) -> "TranslatableItemUnion": + self._alias = alias + return self + class ProductTranslatableContentGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "ProductTranslatableContentGraphQLField": + self._alias = alias + return self class CollectionTranslatableContentGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "CollectionTranslatableContentGraphQLField": + self._alias = alias + return self diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py index 345273c9..d5f29221 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -40,6 +40,10 @@ def fields(self, *subfields: AdminGraphQLField) -> "AdminFields": self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "AdminFields": + self._alias = alias + return self + class GuestFields(GraphQLField): id: "GuestGraphQLField" = GuestGraphQLField("id") @@ -63,6 +67,10 @@ def fields(self, *subfields: GuestGraphQLField) -> "GuestFields": self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "GuestFields": + self._alias = alias + return self + class PersonInterfaceInterface(GraphQLField): id: "PersonInterfaceGraphQLField" = PersonInterfaceGraphQLField("id") @@ -86,6 +94,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "PersonInterfaceInterface": + self._alias = alias + return self + def on( self, type_name: str, *subfields: GraphQLField ) -> "PersonInterfaceInterface": @@ -111,6 +123,10 @@ def fields( self._subfields.extend(subfields) return self + def alias(self, alias: str) -> "PostFields": + self._alias = alias + return self + class UserFields(GraphQLField): id: "UserGraphQLField" = UserGraphQLField("id") @@ -138,3 +154,7 @@ def fields(self, *subfields: Union[UserGraphQLField, "UserFields"]) -> "UserFiel """Subfields should come from the UserFields class""" self._subfields.extend(subfields) return self + + def alias(self, alias: str) -> "UserFields": + self._alias = alias + return self diff --git a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py index 826f8f2a..b4dd2471 100644 --- a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py +++ b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py @@ -2,26 +2,40 @@ class PersonInterfaceGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "PersonInterfaceGraphQLField": + self._alias = alias + return self class UserGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "UserGraphQLField": + self._alias = alias + return self class AdminGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "AdminGraphQLField": + self._alias = alias + return self class GuestGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "GuestGraphQLField": + self._alias = alias + return self class PostGraphQLField(GraphQLField): - pass + def alias(self, alias: str) -> "PostGraphQLField": + self._alias = alias + return self class SearchResultUnion(GraphQLField): def on(self, type_name: str, *subfields: GraphQLField) -> "SearchResultUnion": self._inline_fragments[type_name] = subfields return self + + def alias(self, alias: str) -> "SearchResultUnion": + self._alias = alias + return self diff --git a/tests/main/custom_operation_builder/test_operation_build.py b/tests/main/custom_operation_builder/test_operation_build.py index cdd39657..19e4c83f 100644 --- a/tests/main/custom_operation_builder/test_operation_build.py +++ b/tests/main/custom_operation_builder/test_operation_build.py @@ -19,6 +19,12 @@ def test_simple_hello(): assert built_query == expected_query +def test_alias(): + built_query = print_ast(Query.hello().alias("aliased_hello").to_ast(0)) + expected_query = "aliased_hello: hello" + assert built_query == expected_query + + def test_greeting_with_name(): query = Query.greeting(name="Alice") expected_query = "greeting(name: $name_0)" From ca71d401d27da18bd81ebb807fe2d2d6fb76954a Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Wed, 17 Jul 2024 14:55:15 +0200 Subject: [PATCH 11/11] Bump version to 0.14.0 --- CHANGELOG.md | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b9b9f9a..4630d466 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # CHANGELOG -## 0.14.0 (Unreleased) +## 0.14.0 (2024-07-17) - Added `ClientForwardRefsPlugin` to standard plugins. - Re-added `model_rebuild` calls for input types with forward references. diff --git a/pyproject.toml b/pyproject.toml index 70a91ae4..de4589c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "ariadne-codegen" description = "Generate fully typed GraphQL client from schema, queries and mutations!" authors = [{ name = "Mirumee Software", email = "hello@mirumee.com" }] -version = "0.14.0.dev2" +version = "0.14.0" readme = "README.md" license = { file = "LICENSE" } classifiers = [