diff --git a/CHANGELOG.md b/CHANGELOG.md index f4d7f2cb..4630d466 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,12 @@ # CHANGELOG -## 0.14.0 (Unreleased) +## 0.14.0 (2024-07-17) - Added `ClientForwardRefsPlugin` to standard plugins. - Re-added `model_rebuild` calls for input types with forward references. - Fixed fragments on interfaces being omitted from generated client. - Fixed `@Include` directive result type when using `convert_to_snake_case` option. +- Added Custom query builder feature. ## 0.13.0 (2024-03-4) diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 956bb9be..84caf3dc 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -14,10 +14,14 @@ generate_await, generate_call, generate_class_def, + generate_comp, generate_constant, + generate_dict, generate_expr, generate_import_from, generate_keyword, + generate_list, + generate_list_comp, generate_method_definition, generate_module, generate_name, @@ -32,15 +36,29 @@ from .constants import ( ANY, ASYNC_ITERATOR, + BASE_GRAPHQL_FIELD_CLASS_NAME, + BASE_OPERATION_FILE_PATH, DICT, + DOCUMENT_NODE, + GRAPHQL_MODULE, KWARGS_NAMES, LIST, MODEL_VALIDATE_METHOD, + NAME_NODE, + NAMED_TYPE_NODE, + OPERATION_DEFINITION_NODE, + OPERATION_TYPE, OPTIONAL, + PRINT_AST, + SELECTION_NODE, + SELECTION_SET_NODE, + TUPLE, TYPING_MODULE, UNION, UNSET_IMPORT, UPLOAD_IMPORT, + VARIABLE_DEFINITION_NODE, + VARIABLE_NODE, ) from .scalars import ScalarData, generate_scalar_imports @@ -66,10 +84,18 @@ def __init__( self.custom_scalars = custom_scalars if custom_scalars else {} self.arguments_generator = arguments_generator - self._imports: List[ast.ImportFrom] = [] + self._imports: List[Union[ast.ImportFrom, ast.Import]] = [] self._add_import( generate_import_from( - [OPTIONAL, LIST, DICT, ANY, UNION, ASYNC_ITERATOR], TYPING_MODULE + [ + OPTIONAL, + LIST, + DICT, + ANY, + UNION, + ASYNC_ITERATOR, + ], + TYPING_MODULE, ) ) self._add_import(base_client_import) @@ -187,6 +213,529 @@ def add_method( generate_import_from(names=[return_type], from_=return_type_module, level=1) ) + def create_combine_variables_method(self): + method_body = [ + generate_assign( + targets=["variables_types_combined"], + value=generate_dict(), + ), + generate_assign( + targets=["processed_variables_combined"], + value=generate_dict(), + ), + ast.For( + target=generate_name("field"), + iter=generate_name("fields"), + body=[ + generate_assign( + targets=["formatted_variables"], + value=generate_call( + func=generate_name("field.get_formatted_variables") + ), + ), + generate_expr( + value=generate_call( + func=generate_attribute( + value=generate_name("variables_types_combined"), + attr="update", + ), + args=[ + ast.DictComp( + key=generate_name("k"), + value=generate_name('v["type"]'), + generators=[ + generate_comp( + target="k, v", + iter_="formatted_variables.items()", + ) + ], + ) + ], + ) + ), + generate_expr( + value=generate_call( + func=generate_attribute( + value=generate_name("processed_variables_combined"), + attr="update", + ), + args=[ + ast.DictComp( + key=generate_name("k"), + value=generate_name('v["value"]'), + generators=[ + generate_comp( + target="k, v", + iter_="formatted_variables.items()", + ) + ], + ) + ], + ) + ), + ], + orelse=[], + lineno=1, + ), + generate_return( + value=generate_dict( + keys=[generate_constant("types"), generate_constant("values")], + values=[ + generate_name("variables_types_combined"), + generate_name("processed_variables_combined"), + ], + ) + ), + ] + + args = generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + name="fields", + annotation=generate_subscript( + generate_name(TUPLE), + generate_tuple( + [ + generate_name("GraphQLField"), + generate_name("..."), + ] + ), + ), + ), + ], + ) + + returns = generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_subscript( + generate_name(DICT), + generate_tuple([generate_name("str"), generate_name("Any")]), + ), + ] + ), + ) + + method_def = generate_method_definition( + name="_combine_variables", + arguments=args, + body=method_body, + decorator_list=[], + return_type=returns, + ) + + return method_def + + def create_build_variable_definitions_method(self): + method_body = [ + generate_return( + value=generate_list_comp( + elt=generate_call( + func=generate_name("VariableDefinitionNode"), + keywords=[ + generate_keyword( + arg="variable", + value=generate_call( + func=generate_name("VariableNode"), + keywords=[ + generate_keyword( + arg="name", + value=generate_call( + func=generate_name("NameNode"), + keywords=[ + generate_keyword( + arg="value", + value=generate_name("var_name"), + ) + ], + ), + ) + ], + ), + ), + generate_keyword( + arg="type", + value=generate_call( + func=generate_name("NamedTypeNode"), + keywords=[ + generate_keyword( + arg="name", + value=generate_call( + func=generate_name("NameNode"), + keywords=[ + generate_keyword( + arg="value", + value=generate_name( + "var_value", + ), + ) + ], + ), + ) + ], + ), + ), + ], + ), + generators=[ + generate_comp( + target="var_name, var_value", + iter_="variables_types_combined.items()", + ) + ], + ) + ) + ] + return generate_method_definition( + name="_build_variable_definitions", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + "variables_types_combined", + annotation=generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_name("str"), + ] + ), + ), + ), + ] + ), + body=method_body, + return_type=generate_subscript( + generate_name("List"), generate_name("VariableDefinitionNode") + ), + ) + + def create_build_operation_ast_method(self): + keywords = [ + generate_keyword( + arg="definitions", + value=generate_list( + elements=[ + generate_call( + func=generate_name("OperationDefinitionNode"), + keywords=[ + generate_keyword( + arg="operation", + value=generate_name("operation_type"), + ), + generate_keyword( + arg="name", + value=generate_call( + func=generate_name("NameNode"), + keywords=[ + generate_keyword( + arg="value", + value=generate_name( + "operation_name", + ), + ) + ], + ), + ), + generate_keyword( + arg="variable_definitions", + value=generate_name( + "variable_definitions", + ), + ), + generate_keyword( + arg="selection_set", + value=generate_call( + func=generate_name( + "SelectionSetNode", + ), + keywords=[ + generate_keyword( + arg="selections", + value=generate_name("selections"), + ) + ], + ), + ), + ], + ) + ] + ), + ) + ] + method_body = [ + generate_return( + value=generate_call( + func=generate_name("DocumentNode"), + keywords=keywords, + ) + ) + ] + return generate_method_definition( + name="_build_operation_ast", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + "selections", + annotation=generate_subscript( + generate_name(LIST), + generate_name(SELECTION_NODE), + ), + ), + generate_arg( + "operation_type", + annotation=generate_name("OperationType"), + ), + generate_arg("operation_name", annotation=generate_name("str")), + generate_arg( + "variable_definitions", + annotation=generate_subscript( + generate_name("List"), + generate_name("VariableDefinitionNode"), + ), + ), + ] + ), + body=method_body, + return_type=generate_name("DocumentNode"), + ) + + def create_execute_custom_operation_method(self): + method_body = [ + generate_assign( + targets=["selections"], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_build_selection_set" + ), + args=[generate_name("fields")], + ), + ), + ast.Assign( + targets=[generate_name("combined_variables")], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_combine_variables" + ), + args=[generate_name("fields")], + ), + lineno=1, + ), + generate_assign( + targets=["variable_definitions"], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_build_variable_definitions" + ), + args=[generate_name('combined_variables["types"]')], + ), + ), + generate_assign( + targets=["operation_ast"], + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_build_operation_ast" + ), + args=[ + generate_name("selections"), + generate_name("operation_type"), + generate_name("operation_name"), + generate_name("variable_definitions"), + ], + ), + ), + generate_assign( + targets=["response"], + value=generate_await( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), + attr="execute", + ), + args=[ + generate_call( + func=generate_name("print_ast"), + args=[generate_name("operation_ast")], + ) + ], + keywords=[ + generate_keyword( + arg="variables", + value=generate_name('combined_variables["values"]'), + ), + generate_keyword( + arg="operation_name", + value=generate_name("operation_name"), + ), + ], + ) + ), + ), + generate_return( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), + attr="get_data", + ), + args=[generate_name("response")], + ) + ), + ] + return generate_async_method_definition( + name="execute_custom_operation", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg("*fields", annotation=generate_name("GraphQLField")), + generate_arg( + "operation_type", + annotation=generate_name( + "OperationType", + ), + ), + generate_arg("operation_name", annotation=generate_name("str")), + ] + ), + body=method_body, + return_type=generate_subscript( + generate_name(DICT), + generate_tuple([generate_name("str"), generate_name("Any")]), + ), + ) + + def create_build_selection_set(self): + return generate_method_definition( + name="_build_selection_set", + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg( + "fields", + annotation=generate_subscript( + generate_name("Tuple"), + generate_tuple( + [ + generate_name("GraphQLField"), + generate_name("..."), + ] + ), + ), + ), + ] + ), + body=[ + generate_return( + value=generate_list_comp( + elt=generate_call( + func=generate_attribute( + value=generate_name( + "field", + ), + attr="to_ast", + ), + args=[generate_name("idx")], + ), + generators=[ + generate_comp( + target="idx, field", + iter_="enumerate(fields)", + ) + ], + ), + ), + ], + return_type=generate_subscript( + generate_name(LIST), + generate_name(SELECTION_NODE), + ), + ) + + def add_execute_custom_operation_method(self): + self._add_import( + generate_import_from( + [ + DOCUMENT_NODE, + OPERATION_DEFINITION_NODE, + NAME_NODE, + SELECTION_SET_NODE, + PRINT_AST, + VARIABLE_DEFINITION_NODE, + VARIABLE_NODE, + NAMED_TYPE_NODE, + SELECTION_NODE, + ], + GRAPHQL_MODULE, + ) + ) + self._add_import( + generate_import_from( + [BASE_GRAPHQL_FIELD_CLASS_NAME], BASE_OPERATION_FILE_PATH.stem, level=1 + ) + ) + self._add_import(generate_import_from([DICT, TUPLE, LIST, ANY], "typing")) + + self._class_def.body.append(self.create_execute_custom_operation_method()) + self._class_def.body.append(self.create_combine_variables_method()) + self._class_def.body.append(self.create_build_variable_definitions_method()) + self._class_def.body.append(self.create_build_operation_ast_method()) + self._class_def.body.append(self.create_build_selection_set()) + + def create_custom_operation_method(self, name, operation_type): + self._add_import( + generate_import_from( + [ + OPERATION_TYPE, + ], + GRAPHQL_MODULE, + ) + ) + body_return = generate_return( + value=generate_await( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), + attr="execute_custom_operation", + ), + args=[ + generate_name("*fields"), + ], + keywords=[ + generate_keyword( + arg="operation_type", + value=generate_attribute( + value=generate_name("OperationType"), + attr=operation_type, + ), + ), + generate_keyword( + arg="operation_name", value=generate_name("operation_name") + ), + ], + ) + ) + ) + async_def_query = generate_async_method_definition( + name=name, + arguments=generate_arguments( + args=[ + generate_arg("self"), + generate_arg("*fields", annotation=generate_name("GraphQLField")), + generate_arg("operation_name", annotation=generate_name("str")), + ], + ), + body=[body_return], + return_type=generate_subscript( + generate_name(DICT), + generate_tuple([generate_name("str"), generate_name("Any")]), + ), + ) + self._class_def.body.append(async_def_query) + def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]: mapped_variable_names = [ self._operation_str_variable, diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index fe614d27..e7339429 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -16,17 +16,39 @@ LIST = "List" UNION = "Union" ANY = "Any" +TYPE = "Type" +TYPE_CHECKING = "TYPE_CHECKING" DICT = "Dict" +TUPLE = "Tuple" CALLABLE = "Callable" ANNOTATED = "Annotated" LITERAL = "Literal" ASYNC_ITERATOR = "AsyncIterator" +DOCUMENT_NODE = "DocumentNode" +OPERATION_DEFINITION_NODE = "OperationDefinitionNode" +NAME_NODE = "NameNode" +SELECTION_SET_NODE = "SelectionSetNode" +SELECTION_NODE = "SelectionNode" +PRINT_AST = "print_ast" +OPERATION_TYPE = "OperationType" +VARIABLE_DEFINITION_NODE = "VariableDefinitionNode" +VARIABLE_NODE = "VariableNode" +NAMED_TYPE_NODE = "NamedTypeNode" + +HTTPX = "httpx" +HTTPX_RESPONSE = "httpx.Response" TIMESTAMP_COMMENT = "# Generated by ariadne-codegen on {}" STABLE_COMMENT = "# Generated by ariadne-codegen" SOURCE_COMMENT = "# Source: {}" COMMENT_DATETIME_FORMAT = "%Y-%m-%d %H:%M" +BASE_OPERATION_FILE_PATH = Path(__file__).parent / "dependencies" / "base_operation.py" +BASE_GRAPHQL_OPERATION_CLASS_NAME = "BaseGraphQLOperation" +BASE_GRAPHQL_FIELD_CLASS_NAME = "GraphQLField" +CUSTOM_FIELDS_FILE_PATH = Path(__file__).parent / "custom_fields.py" +CUSTOM_FIELDS_TYPING_FILE_PATH = Path(__file__).parent / "custom_typing_fields.py" + BASE_MODEL_FILE_PATH = Path(__file__).parent / "dependencies" / "base_model.py" BASE_MODEL_CLASS_NAME = "BaseModel" BASE_MODEL_IMPORT = ast.ImportFrom( @@ -49,6 +71,7 @@ TYPENAME_ALIAS = "typename__" TYPING_MODULE = "typing" +GRAPHQL_MODULE = "graphql" PYDANTIC_MODULE = "pydantic" FIELD_CLASS = "Field" ALIAS_KEYWORD = "alias" @@ -100,3 +123,11 @@ SCALARS_PARSE_DICT_NAME = "SCALARS_PARSE_FUNCTIONS" SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS" + +OPERATION_TYPES = ("Query", "Mutation", "Subscription") + +GRAPHQL_OBJECT_SUFFIX = "Fields" +GRAPHQL_INTERFACE_SUFFIX = "Interface" +GRAPHQL_FIELD_SUFFIX = "GraphQLField" +GRAPHQL_UNION_SUFFIX = "Union" +GRAPHQL_BASE_FIELD_CLASS = "GraphQLField" diff --git a/ariadne_codegen/client_generators/custom_arguments.py b/ariadne_codegen/client_generators/custom_arguments.py new file mode 100644 index 00000000..cd0700e4 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_arguments.py @@ -0,0 +1,277 @@ +import ast +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from graphql import ( + GraphQLEnumType, + GraphQLInputObjectType, + GraphQLInterfaceType, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLUnionType, +) + +from ..codegen import ( + generate_ann_assign, + generate_annotation_name, + generate_arg, + generate_arguments, + generate_assign, + generate_call, + generate_comp, + generate_constant, + generate_dict, + generate_import_from, + generate_keyword, + generate_name, + generate_subscript, + generate_tuple, +) +from ..exceptions import ParsingError +from ..plugins.manager import PluginManager +from ..utils import process_name +from .constants import ( + ANY, + BASE_MODEL_FILE_PATH, + DICT, + INPUT_SCALARS_MAP, + UPLOAD_CLASS_NAME, +) +from .custom_generator_utils import get_final_type +from .scalars import ScalarData, generate_scalar_imports + + +class ArgumentGenerator: + """Generates method arguments for GraphQL fields.""" + + def __init__( + self, + custom_scalars: Dict[str, ScalarData], + convert_to_snake_case: bool, + plugin_manager: Optional[PluginManager] = None, + ) -> None: + self.custom_scalars = custom_scalars + self.convert_to_snake_case = convert_to_snake_case + self.plugin_manager = plugin_manager + self.imports: List[ast.ImportFrom] = [] + self._used_custom_scalars: List[str] = [] + + def _add_import(self, import_: Optional[ast.ImportFrom] = None) -> None: + """Adds an import statement to the list of imports.""" + if import_: + if self.plugin_manager: + import_ = self.plugin_manager.generate_client_import(import_) + if import_.names: + self.imports.append(import_) + + def generate_arguments( + self, operation_args: Dict[str, Any] + ) -> Tuple[ast.arguments, List[ast.expr], List[ast.expr]]: + """Generates method arguments from operation arguments.""" + cls_arg = generate_arg(name="cls") + args: List[ast.arg] = [] + kw_only_args: List[ast.arg] = [] + kw_defaults: List[ast.expr] = [] + return_arguments_keys: List[ast.expr] = [] + return_arguments_values: List[ast.expr] = [] + + for arg_name, arg_value in operation_args.items(): + final_type = get_final_type(arg_value) + is_required = isinstance(arg_value.type, GraphQLNonNull) + name = process_name( + arg_name, convert_to_snake_case=self.convert_to_snake_case + ) + annotation, used_custom_scalar = self._parse_graphql_type_name( + final_type, not is_required + ) + + self._accumulate_method_arguments( + args, kw_only_args, kw_defaults, name, annotation, is_required + ) + self._accumulate_return_arguments( + return_arguments_keys, + return_arguments_values, + arg_name, + name, + final_type, + is_required, + used_custom_scalar, + ) + + method_arguments = self._assemble_method_arguments( + cls_arg, args, kw_only_args, kw_defaults + ) + return method_arguments, return_arguments_keys, return_arguments_values + + def _accumulate_method_arguments( + self, + args: List[ast.arg], + kw_only_args: List[ast.arg], + kw_defaults: List[ast.expr], + name: str, + annotation: Optional[Union[ast.Name, ast.Subscript]], + is_required: bool, + ) -> None: + """Accumulates method arguments.""" + if is_required: + args.append(generate_arg(name=name, annotation=annotation)) + else: + kw_only_args.append(generate_arg(name=name, annotation=annotation)) + kw_defaults.append(generate_constant(value=None)) + + def _accumulate_return_arguments( + self, + return_arguments_keys: List[ast.expr], + return_arguments_values: List[ast.expr], + arg_name: str, + name: str, + final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], + is_required: bool, + used_custom_scalar: Optional[str], + ) -> None: + """Accumulates return arguments.""" + constant_value = f"{final_type.name}!" if is_required else final_type.name + return_arg_dict_value = self._generate_return_arg_value( + name, used_custom_scalar + ) + + return_arguments_keys.append(generate_constant(arg_name)) + return_arguments_values.append( + generate_dict( + keys=[generate_constant("type"), generate_constant("value")], + values=[generate_constant(constant_value), return_arg_dict_value], + ) + ) + + def _generate_return_arg_value( + self, name: str, used_custom_scalar: Optional[str] + ) -> Union[ast.Call, ast.Name]: + """Generates the return argument value.""" + return_arg_dict_value: Union[ast.Call, ast.Name] = generate_name(name) + + if used_custom_scalar: + self._used_custom_scalars.append(used_custom_scalar) + scalar_data = self.custom_scalars[used_custom_scalar] + if scalar_data.serialize_name: + return_arg_dict_value = generate_call( + func=generate_name(scalar_data.serialize_name), + args=[generate_name(name)], + ) + + return return_arg_dict_value + + def _assemble_method_arguments( + self, + cls_arg: ast.arg, + args: List[ast.arg], + kw_only_args: List[ast.arg], + kw_defaults: List[ast.expr], + ) -> ast.arguments: + """Assembles method arguments.""" + return generate_arguments( + args=[cls_arg, *args], + kwonlyargs=kw_only_args, + kw_defaults=kw_defaults, # type: ignore + ) + + def _parse_graphql_type_name( + self, + type_: Union[GraphQLScalarType, GraphQLInputObjectType, GraphQLEnumType], + nullable: bool = True, + ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: + """Parses the GraphQL type name and determines if it is a custom scalar.""" + name = type_.name + used_custom_scalar = None + if isinstance(type_, GraphQLInputObjectType): + self._add_import( + generate_import_from(names=[name], from_="input_types", level=1) + ) + elif isinstance(type_, GraphQLEnumType): + self._add_import(generate_import_from(names=[name], level=1)) + elif isinstance(type_, GraphQLScalarType): + if name not in self.custom_scalars: + name = INPUT_SCALARS_MAP.get(name, ANY) + if name == UPLOAD_CLASS_NAME: + self._add_import( + generate_import_from( + names=[UPLOAD_CLASS_NAME], + from_=BASE_MODEL_FILE_PATH.stem, + level=1, + ) + ) + else: + used_custom_scalar = name + name = self.custom_scalars[name].type_name + self._used_custom_scalars.append(used_custom_scalar) + else: + raise ParsingError(f"Incorrect argument type {name}") + return generate_annotation_name(name, nullable), used_custom_scalar + + def add_custom_scalar_imports(self) -> None: + """Adds imports for custom scalars used in the schema.""" + for custom_scalar_name in self._used_custom_scalars: + scalar_data = self.custom_scalars[custom_scalar_name] + for import_ in generate_scalar_imports(scalar_data): + self._add_import(import_) + + def generate_clear_arguments_section( + self, + return_arguments_keys: List[ast.expr], + return_arguments_values: List[ast.expr], + ) -> Tuple[List[ast.stmt], List[ast.keyword]]: + arguments_body = [ + generate_ann_assign( + "arguments", + generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_subscript( + generate_name(DICT), + generate_tuple( + [ + generate_name("str"), + generate_name(ANY), + ] + ), + ), + ] + ), + ), + generate_dict( + return_arguments_keys, + return_arguments_values, # type: ignore + ), + ), + generate_assign( + ["cleared_arguments"], + ast.DictComp( + key=generate_name("key"), + value=generate_name("value"), + generators=[ + generate_comp( + target="key, value", + iter_="arguments.items()", + ifs=cast( + List[ast.expr], + [ + ast.Compare( + left=generate_subscript( + value=generate_name("value"), + slice_=generate_constant("value"), + ), + ops=[ast.IsNot()], + comparators=[generate_constant(None)], + ) + ], + ), + ) + ], + ), + ), + ] + arguments_keyword = [ + generate_keyword(arg="arguments", value=generate_name("cleared_arguments")) + ] + return arguments_body, arguments_keyword diff --git a/ariadne_codegen/client_generators/custom_fields.py b/ariadne_codegen/client_generators/custom_fields.py new file mode 100644 index 00000000..0c0ba369 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_fields.py @@ -0,0 +1,389 @@ +import ast +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from graphql import ( + GraphQLInterfaceType, + GraphQLNamedType, + GraphQLObjectType, + GraphQLSchema, + GraphQLUnionType, +) + +from ariadne_codegen.client_generators.custom_arguments import ArgumentGenerator + +from ..codegen import ( + generate_ann_assign, + generate_arg, + generate_arguments, + generate_attribute, + generate_call, + generate_class_def, + generate_constant, + generate_expr, + generate_import_from, + generate_method_definition, + generate_module, + generate_name, + generate_return, + generate_subscript, + generate_union_annotation, +) +from ..plugins.manager import PluginManager +from ..utils import process_name +from .constants import ( + ANY, + BASE_GRAPHQL_FIELD_CLASS_NAME, + BASE_OPERATION_FILE_PATH, + DICT, + GRAPHQL_BASE_FIELD_CLASS, + GRAPHQL_INTERFACE_SUFFIX, + GRAPHQL_OBJECT_SUFFIX, + GRAPHQL_UNION_SUFFIX, + OPTIONAL, + TYPING_MODULE, + UNION, +) +from .custom_generator_utils import TypeCollector, get_final_type +from .scalars import ScalarData + + +class CustomFieldsGenerator: + """Generates custom fields for a given GraphQL schema using Python's AST module.""" + + def __init__( + self, + schema: GraphQLSchema, + convert_to_snake_case: bool = True, + custom_scalars: Optional[Dict[str, ScalarData]] = None, + plugin_manager: Optional[PluginManager] = None, + ) -> None: + self.schema = schema + self.convert_to_snake_case = convert_to_snake_case + self.plugin_manager = plugin_manager + self.custom_scalars = custom_scalars if custom_scalars else {} + self._imports: List[ast.ImportFrom] = [ + ast.ImportFrom( + module=BASE_OPERATION_FILE_PATH.stem, + names=[ast.alias(BASE_GRAPHQL_FIELD_CLASS_NAME)], + level=1, + ) + ] + self._add_import( + generate_import_from( + [OPTIONAL, UNION, ANY, DICT], + TYPING_MODULE, + ) + ) + self.argument_generator = ArgumentGenerator( + self.custom_scalars, + self.convert_to_snake_case, + self.plugin_manager, + ) + self._class_defs: List[ast.ClassDef] = self._parse_object_type_definitions( + TypeCollector(self.schema).collect() + ) + + def generate(self) -> ast.Module: + """Generates an AST module containing the custom fields and required imports.""" + self.argument_generator.add_custom_scalar_imports() + module = generate_module( + body=cast(List[ast.stmt], self._imports + self._class_defs), + ) + return module + + def _add_import(self, import_: Optional[ast.ImportFrom] = None) -> None: + """Adds an import statement to the list of imports.""" + if import_: + if self.plugin_manager: + import_ = self.plugin_manager.generate_client_import(import_) + if import_.names: + self._imports.append(import_) + + def _parse_object_type_definitions( + self, type_names: List[str] + ) -> List[ast.ClassDef]: + """ + Parses object type definitions from the schema + and generates AST class definitions. + """ + class_defs = [] + + for type_name in type_names: + graphql_type = self.schema.get_type(type_name) + if isinstance(graphql_type, (GraphQLObjectType, GraphQLInterfaceType)): + class_def = self._generate_class_def_body( + definition=graphql_type, + class_name=f"{graphql_type.name}{self._get_suffix(graphql_type)}", + ) + if isinstance(graphql_type, GraphQLInterfaceType): + class_def.body.append( + self._generate_on_method( + f"{graphql_type.name}{GRAPHQL_INTERFACE_SUFFIX}" + ) + ) + class_defs.append(class_def) + + return class_defs + + def _generate_class_def_body( + self, + definition: Union[GraphQLObjectType, GraphQLInterfaceType], + class_name: str, + ) -> ast.ClassDef: + """ + Generates the body of a class definition for a given GraphQL object + or interface type. + """ + base_names = [GRAPHQL_BASE_FIELD_CLASS] + additional_fields_typing = set() + class_def = generate_class_def(name=class_name, base_names=base_names) + for lineno, (org_name, field) in enumerate( + self._get_combined_fields(definition).items(), start=1 + ): + name = process_name( + org_name, convert_to_snake_case=self.convert_to_snake_case + ) + final_type = get_final_type(field) + field_name, method_required = self._get_field_name( + final_type, definition.name + ) + if self._is_custom_type(final_type): + additional_fields_typing.add(field_name) + class_def.body.append( + self._generate_class_field( + name, field_name, org_name, field, method_required, lineno + ) + ) + + class_def.body.append( + self._generate_fields_method( + class_name, definition.name, sorted(additional_fields_typing) + ) + ) + class_def.body.append(self._generate_alias_method(class_name)) + return class_def + + def _get_combined_fields( + self, definition: Union[GraphQLObjectType, GraphQLInterfaceType] + ) -> Dict[str, ast.ClassDef]: + """Combines fields from the definition and its interfaces.""" + fields = dict(definition.fields.items()) + for interface in getattr(definition, "interfaces", []): + fields.update(dict(interface.fields.items())) + return fields + + def _get_field_name( + self, final_type: GraphQLNamedType, definition_name: str + ) -> Tuple[str, bool]: + """ + Returns the appropriate field name suffix based on the type of GraphQL type. + """ + if isinstance(final_type, GraphQLObjectType): + return f"{final_type.name}{GRAPHQL_OBJECT_SUFFIX}", True + if isinstance(final_type, GraphQLInterfaceType): + return f"{final_type.name}{GRAPHQL_INTERFACE_SUFFIX}", True + if isinstance(final_type, GraphQLUnionType): + field_name = f"{final_type.name}{GRAPHQL_UNION_SUFFIX}" + else: + field_name = f"{definition_name}{GRAPHQL_BASE_FIELD_CLASS}" + self._add_import( + generate_import_from( + [field_name], + from_="custom_typing_fields", + level=1, + ) + ) + return field_name, False + + def _is_custom_type( + self, + final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], + ) -> bool: + """Checks if the final type is a custom type (Object, Interface, or Union).""" + return isinstance( + final_type, (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType) + ) + + def _generate_class_field( + self, + name: str, + field_name: str, + org_name: str, + field: ast.ClassDef, + method_required: bool, + lineno: int, + ) -> ast.stmt: + """Handles the generation of field types.""" + if getattr(field, "args") or method_required: + return self.generate_product_type_method( + name, field_name, getattr(field, "args") + ) + return generate_ann_assign( + target=name, + annotation=generate_name(f'"{field_name}"'), + value=generate_call( + func=generate_name(field_name), args=[generate_constant(org_name)] + ), + lineno=lineno, + ) + + def _generate_fields_method( + self, class_name: str, definition_name: str, additional_fields_typing: List[str] + ) -> ast.FunctionDef: + """Generates the `fields` method for a class.""" + field_class_name = generate_name(f"{definition_name}{GRAPHQL_BASE_FIELD_CLASS}") + self._add_import( + generate_import_from( + [field_class_name.id], from_="custom_typing_fields", level=1 + ) + ) + fields_annotation: Union[ast.Name, ast.Subscript] = field_class_name + if additional_fields_typing: + additional_fields_typing_ann = [ + generate_name(f'"{field_typing}"') + for field_typing in additional_fields_typing + ] + fields_annotation = generate_union_annotation( + [field_class_name, *additional_fields_typing_ann], nullable=False + ) + + return generate_method_definition( + "fields", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="*subfields", annotation=fields_annotation), + ] + ), + body=[ + generate_expr( + value=generate_constant( + value=f"Subfields should come from the {class_name} class" + ) + ), + generate_expr( + value=generate_call( + func=generate_attribute( + value=generate_name("self"), attr="_subfields.extend" + ), + args=[generate_name("subfields")], + ) + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) + + def _generate_on_method(self, class_name: str) -> ast.FunctionDef: + """Generates the `on` method for a class.""" + return generate_method_definition( + "on", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="type_name", annotation=generate_name("str")), + generate_arg( + name="*subfields", + annotation=generate_name(GRAPHQL_BASE_FIELD_CLASS), + ), + ] + ), + body=cast( + List[ast.stmt], + [ + ast.Assign( + targets=[ + generate_subscript( + value=generate_attribute( + value=generate_name("self"), + attr="_inline_fragments", + ), + slice_=generate_name("type_name"), + ) + ], + value=generate_name("subfields"), + lineno=1, + ), + generate_return(value=generate_name("self")), + ], + ), + return_type=generate_name(f'"{class_name}"'), + ) + + def generate_product_type_method( + self, name: str, class_name: str, arguments: Optional[Dict[str, Any]] = None + ) -> ast.FunctionDef: + """Generates a method for a product type.""" + arguments = arguments or {} + field_class_name = generate_name(class_name) + ( + method_arguments, + return_arguments_keys, + return_arguments_values, + ) = self.argument_generator.generate_arguments(arguments) + self._imports.extend(self.argument_generator.imports) + arguments_body: List[ast.stmt] = [] + arguments_keyword: List[ast.keyword] = [] + + if arguments: + ( + arguments_body, + arguments_keyword, + ) = self.argument_generator.generate_clear_arguments_section( + return_arguments_keys, return_arguments_values + ) + + return generate_method_definition( + name, + arguments=method_arguments, + body=cast( + List[ast.stmt], + [ + *arguments_body, + generate_return( + value=generate_call( + func=field_class_name, + args=[generate_constant(name)], + keywords=arguments_keyword, + ) + ), + ], + ), + return_type=generate_name(f'"{class_name}"'), + decorator_list=[generate_name("classmethod")], + ) + + def _get_suffix( + self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType] + ) -> str: + """Gets the appropriate suffix for a GraphQL type.""" + if isinstance(graphql_type, GraphQLObjectType): + return GRAPHQL_OBJECT_SUFFIX + if isinstance(graphql_type, GraphQLInterfaceType): + return GRAPHQL_INTERFACE_SUFFIX + raise ValueError(f"Unexpected graphql_type: {graphql_type}") + + def _generate_alias_method(self, class_name: str) -> ast.FunctionDef: + """ + Generates the `alias` method for a class. + """ + return generate_method_definition( + "alias", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="alias", annotation=generate_name("str")), + ] + ), + body=[ + ast.Assign( + targets=[ + generate_attribute(value=generate_name("self"), attr="_alias"), + ], + value=generate_name("alias"), + lineno=1, + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) diff --git a/ariadne_codegen/client_generators/custom_fields_typing.py b/ariadne_codegen/client_generators/custom_fields_typing.py new file mode 100644 index 00000000..54adb933 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_fields_typing.py @@ -0,0 +1,155 @@ +import ast +from typing import List, cast + +from graphql import ( + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLSchema, + GraphQLUnionType, +) + +from ariadne_codegen.client_generators.custom_generator_utils import get_final_type + +from ..codegen import ( + generate_arg, + generate_arguments, + generate_attribute, + generate_class_def, + generate_method_definition, + generate_module, + generate_name, + generate_return, + generate_subscript, +) +from .constants import ( + BASE_OPERATION_FILE_PATH, + GRAPHQL_BASE_FIELD_CLASS, + OPERATION_TYPES, +) + + +class CustomFieldsTypingGenerator: + def __init__(self, schema: GraphQLSchema) -> None: + self.schema = schema + self.graphql_field_import = ast.ImportFrom( + module=BASE_OPERATION_FILE_PATH.stem, + names=[ast.alias(GRAPHQL_BASE_FIELD_CLASS)], + level=1, + ) + self._public_names: List[str] = [] + self._class_defs: List[ast.ClassDef] = [ + self._generate_field_class(d) for d in self._filter_types() + ] + + def generate(self) -> ast.Module: + """ + Generates an AST module containing the custom fields and required imports. + """ + return generate_module( + body=cast(List[ast.stmt], [self.graphql_field_import]) + + cast(List[ast.stmt], self._class_defs) + ) + + def _filter_types(self) -> List[ast.ClassDef]: + """ + Filters GraphQL types to include only objects, interfaces, and unions, + excluding internal and operation types. + """ + return [ + get_final_type(definition) + for name, definition in self.schema.type_map.items() + if isinstance( + definition, (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType) + ) + and not name.startswith("__") + and name not in OPERATION_TYPES + ] + + def _generate_field_class( + self, + graphql_type: ast.ClassDef, + ) -> ast.ClassDef: + """ + Generates a field class for the given GraphQL type. + """ + class_name = f"{graphql_type.name}{GRAPHQL_BASE_FIELD_CLASS}" + class_body: List[ast.stmt] = [] + + if isinstance(graphql_type, GraphQLUnionType): + class_name = f"{graphql_type.name}Union" + class_body.append(self._generate_on_method(class_name)) + class_body.append(self._generate_alias_method(class_name)) + if class_name not in self._public_names: + self._public_names.append(class_name) + + field_class_def = generate_class_def( + name=class_name, + base_names=[GRAPHQL_BASE_FIELD_CLASS], + body=class_body if class_body else cast(List[ast.stmt], [ast.Pass()]), + ) + return field_class_def + + def _generate_on_method(self, class_name: str) -> ast.FunctionDef: + """ + Generates the `on` method for a class. + """ + return generate_method_definition( + "on", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="type_name", annotation=generate_name("str")), + generate_arg( + name="*subfields", + annotation=generate_name(GRAPHQL_BASE_FIELD_CLASS), + ), + ] + ), + body=[ + ast.Assign( + targets=[ + generate_subscript( + value=generate_attribute( + value=generate_name("self"), attr="_inline_fragments" + ), + slice_=generate_name("type_name"), + ) + ], + value=generate_name("subfields"), + lineno=1, + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) + + def _generate_alias_method(self, class_name: str) -> ast.FunctionDef: + """ + Generates the `alias` method for a class. + """ + return generate_method_definition( + "alias", + arguments=generate_arguments( + [ + generate_arg(name="self"), + generate_arg(name="alias", annotation=generate_name("str")), + ] + ), + body=[ + ast.Assign( + targets=[ + generate_attribute(value=generate_name("self"), attr="_alias"), + ], + value=generate_name("alias"), + lineno=1, + ), + generate_return(value=generate_name("self")), + ], + return_type=generate_name(f'"{class_name}"'), + ) + + def get_generated_public_names(self) -> List[str]: + """ + Returns the list of generated public names. + """ + return self._public_names diff --git a/ariadne_codegen/client_generators/custom_generator_utils.py b/ariadne_codegen/client_generators/custom_generator_utils.py new file mode 100644 index 00000000..81b59402 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_generator_utils.py @@ -0,0 +1,58 @@ +from typing import Dict, List, Set + +from graphql import ( + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLSchema, + GraphQLUnionType, +) + + +class TypeCollector: + def __init__(self, schema: GraphQLSchema): + self.schema = schema + self.collected_types: Set[str] = set() + self.visited_types: Set[str] = set() + + def collect(self) -> List[str]: + if self.schema.query_type: + self._collect_types(self.schema.query_type.fields) + if self.schema.mutation_type: + self._collect_types(self.schema.mutation_type.fields) + return sorted(self.collected_types) + + def _collect_types(self, fields: Dict[str, GraphQLObjectType]) -> None: + for field in fields.values(): + graphql_type = get_final_type(field) + self._collect_dependent_types(graphql_type) + + def _collect_dependent_types(self, graphql_type: GraphQLObjectType) -> None: + stack = [graphql_type] + + while stack: + current_type = stack.pop() + if current_type.name in self.visited_types: + continue + + self.visited_types.add(current_type.name) + self.collected_types.add(current_type.name) + + if isinstance(current_type, (GraphQLObjectType, GraphQLInterfaceType)): + for subfield in current_type.fields.values(): + subfield_type = get_final_type(subfield) + if isinstance(subfield_type, GraphQLObjectType): + stack.append(subfield_type) + elif isinstance(subfield_type, GraphQLUnionType): + stack.extend(subfield_type.types) + for interface in current_type.interfaces: + stack.append(interface) + elif isinstance(current_type, GraphQLUnionType): + stack.extend(current_type.types) + + +def get_final_type(type_): + while hasattr(type_, "of_type"): + type_ = type_.of_type + if hasattr(type_, "type"): + return get_final_type(type_.type) + return type_ diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py new file mode 100644 index 00000000..f9ee64f6 --- /dev/null +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -0,0 +1,191 @@ +import ast +from typing import Dict, List, Optional, cast + +from graphql import ( + GraphQLFieldMap, + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLUnionType, +) + +from ariadne_codegen.client_generators.custom_arguments import ArgumentGenerator + +from ..codegen import ( + generate_call, + generate_class_def, + generate_constant, + generate_import_from, + generate_keyword, + generate_method_definition, + generate_module, + generate_name, + generate_return, +) +from ..plugins.manager import PluginManager +from ..utils import str_to_snake_case +from .arguments import ArgumentsGenerator +from .constants import ( + ANY, + CUSTOM_FIELDS_FILE_PATH, + CUSTOM_FIELDS_TYPING_FILE_PATH, + DICT, + GRAPHQL_BASE_FIELD_CLASS, + GRAPHQL_INTERFACE_SUFFIX, + GRAPHQL_OBJECT_SUFFIX, + GRAPHQL_UNION_SUFFIX, + OPTIONAL, + TYPING_MODULE, +) +from .custom_generator_utils import get_final_type +from .scalars import ScalarData + + +class CustomOperationGenerator: + """ + Generates custom operations for a given GraphQL schema using Python's AST module. + """ + + def __init__( + self, + graphql_fields: GraphQLFieldMap, + name: str, + base_name: str, + arguments_generator: ArgumentsGenerator, + enums_module_name: str = "enums", + custom_scalars: Optional[Dict[str, ScalarData]] = None, + plugin_manager: Optional[PluginManager] = None, + convert_to_snake_case: bool = True, + ) -> None: + self.graphql_fields = graphql_fields + self.name = name + self.base_name = base_name + self.enums_module_name = enums_module_name + self.plugin_manager = plugin_manager + self.custom_scalars = custom_scalars if custom_scalars else {} + self.arguments_generator = arguments_generator + self.convert_to_snake_case = convert_to_snake_case + + self._imports: List[ast.ImportFrom] = [] + self._type_imports: List[ast.ImportFrom] = [] + self._add_import(generate_import_from([OPTIONAL, ANY, DICT], TYPING_MODULE)) + self.argument_generator = ArgumentGenerator( + self.custom_scalars, + self.convert_to_snake_case, + self.plugin_manager, + ) + + self._class_def = generate_class_def(name=name, base_names=[]) + + def generate(self) -> ast.Module: + """Generate module with class definition of graphql client.""" + for name, field in self.graphql_fields.items(): + final_type = get_final_type(field) + method_def = self._generate_method( + operation_name=name, + operation_args=field.args, + final_type=final_type, + ) + method_def.lineno = len(self._class_def.body) + 1 + self._class_def.body.append(method_def) + + if not self._class_def.body: + self._class_def.body.append(ast.Pass()) + + self.argument_generator.add_custom_scalar_imports() + + self._class_def.lineno = len(self._imports) + 3 + + module = generate_module( + body=cast(List[ast.stmt], self._imports) + + cast(List[ast.stmt], self._type_imports) + + [self._class_def], + ) + return module + + def _add_import(self, import_: Optional[ast.ImportFrom] = None): + """Adds an import statement to the list of imports.""" + if import_: + if self.plugin_manager: + import_ = self.plugin_manager.generate_client_import(import_) + if import_.names and import_.module: + self._imports.append(import_) + + def _generate_method( + self, + operation_name: str, + operation_args, + final_type, + ) -> ast.FunctionDef: + """Generates a method definition for a given operation.""" + ( + method_arguments, + return_arguments_keys, + return_arguments_values, + ) = self.argument_generator.generate_arguments(operation_args) + self._imports.extend(self.argument_generator.imports) + + return_type_name = self._get_return_type_and_from(final_type) + + arguments_body: List[ast.stmt] = [] + arguments_keyword: List[ast.keyword] = [] + + if operation_args: + ( + arguments_body, + arguments_keyword, + ) = self.argument_generator.generate_clear_arguments_section( + return_arguments_keys, return_arguments_values + ) + + return generate_method_definition( + name=str_to_snake_case(operation_name), + arguments=method_arguments, + return_type=generate_name(return_type_name), + body=[ + *arguments_body, + generate_return( + value=generate_call( + func=generate_name(return_type_name), + args=[], + keywords=[ + generate_keyword( + arg="field_name", + value=generate_constant(value=operation_name), + ), + *arguments_keyword, + ], + ) + ), + ], + decorator_list=[generate_name("classmethod")], + ) + + def _get_return_type_and_from(self, final_type): + """ + Determines the return type name and its import path based on the final type. + """ + if isinstance(final_type, GraphQLObjectType): + return_type_name = f"{final_type.name}{GRAPHQL_OBJECT_SUFFIX}" + from_ = CUSTOM_FIELDS_FILE_PATH.stem + elif isinstance(final_type, GraphQLInterfaceType): + return_type_name = f"{final_type.name}{GRAPHQL_INTERFACE_SUFFIX}" + from_ = CUSTOM_FIELDS_FILE_PATH.stem + elif isinstance(final_type, GraphQLUnionType): + return_type_name = f"{final_type.name}{GRAPHQL_UNION_SUFFIX}" + from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem + else: + return_type_name = GRAPHQL_BASE_FIELD_CLASS + from_ = CUSTOM_FIELDS_TYPING_FILE_PATH.stem + self._type_imports.append( + generate_import_from( + from_=from_, + names=[return_type_name], + level=1, + ) + ) + return return_type_name + + @staticmethod + def _capitalize_first_letter(s: str) -> str: + """Capitalizes the first letter of the given string.""" + return s[0].upper() + s[1:] diff --git a/ariadne_codegen/client_generators/dependencies/base_operation.py b/ariadne_codegen/client_generators/dependencies/base_operation.py new file mode 100644 index 00000000..9f9c7660 --- /dev/null +++ b/ariadne_codegen/client_generators/dependencies/base_operation.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from graphql import ( + ArgumentNode, + FieldNode, + InlineFragmentNode, + NamedTypeNode, + NameNode, + SelectionSetNode, + VariableNode, +) + + +class GraphQLArgument: + """ + Represents a GraphQL argument and allows conversion to an AST structure. + """ + + def __init__(self, argument_name: str, argument_value: Any) -> None: + self._name = argument_name + self._value = argument_value + + def to_ast(self) -> ArgumentNode: + """Converts the argument to an ArgumentNode AST object.""" + return ArgumentNode( + name=NameNode(value=self._name), + value=VariableNode(name=NameNode(value=self._value)), + ) + + +class GraphQLField: + """ + Represents a GraphQL field with its name, arguments, subfields, alias, + and inline fragments. + + Attributes: + formatted_variables (Dict[str, Dict[str, Any]]): The formatted arguments + of the GraphQL field. + """ + + def __init__( + self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None + ) -> None: + self._field_name = field_name + self._variables = arguments or {} + self.formatted_variables: Dict[str, Dict[str, Any]] = {} + self._subfields: List[GraphQLField] = [] + self._alias: Optional[str] = None + self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} + + def alias(self, alias: str) -> "GraphQLField": + """Sets an alias for the GraphQL field and returns the instance.""" + self._alias = alias + return self + + def _build_field_name(self) -> str: + """Builds the field name, including the alias if present.""" + return f"{self._alias}: {self._field_name}" if self._alias else self._field_name + + def _build_selections( + self, idx: int, used_names: Set[str] + ) -> List[Union[FieldNode, InlineFragmentNode]]: + """Builds the selection set for the current GraphQL field, + including subfields and inline fragments.""" + # Create selections from subfields + selections: List[Union[FieldNode, InlineFragmentNode]] = [ + subfield.to_ast(idx, used_names) for subfield in self._subfields + ] + + # Add inline fragments + for name, subfields in self._inline_fragments.items(): + selections.append( + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[ + subfield.to_ast(idx, used_names) for subfield in subfields + ] + ), + ) + ) + + return selections + + def _format_variable_name( + self, idx: int, var_name: str, used_names: Set[str] + ) -> str: + """Generates a unique variable name by appending an index and, + if necessary, an additional counter to avoid duplicates.""" + base_name = f"{var_name}_{idx}" + unique_name = base_name + counter = 1 + + # Ensure the generated name is unique + while unique_name in used_names: + unique_name = f"{base_name}_{counter}" + counter += 1 + + # Add the unique name to the set of used names + used_names.add(unique_name) + + return unique_name + + def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + """ + Collects and formats all variables for the current GraphQL field, + ensuring unique names. + """ + self.formatted_variables = {} + + for k, v in self._variables.items(): + unique_name = self._format_variable_name(idx, k, used_names) + self.formatted_variables[unique_name] = { + "name": k, + "type": v["type"], + "value": v["value"], + } + + def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + """Converts the current GraphQL field to an AST (Abstract Syntax Tree) node.""" + if used_names is None: + used_names = set() + + self._collect_all_variables(idx, used_names) + + return FieldNode( + name=NameNode(value=self._build_field_name()), + arguments=[ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ], + selection_set=( + SelectionSetNode(selections=self._build_selections(idx, used_names)) + if self._subfields or self._inline_fragments + else None + ), + ) + + def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: + """ + Retrieves all formatted variables for the current GraphQL field, + including those from subfields and inline fragments. + """ + formatted_variables = self.formatted_variables.copy() + + # Collect variables from subfields + for subfield in self._subfields: + subfield.get_formatted_variables() + formatted_variables.update(subfield.formatted_variables) + + # Collect variables from inline fragments + for subfields in self._inline_fragments.values(): + for subfield in subfields: + subfield.get_formatted_variables() + formatted_variables.update(subfield.formatted_variables) + return formatted_variables diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index e17ec5e6..a05cc594 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -2,7 +2,12 @@ from pathlib import Path from typing import Dict, List, Optional, Set -from graphql import FragmentDefinitionNode, GraphQLSchema, OperationDefinitionNode +from graphql import ( + FragmentDefinitionNode, + GraphQLSchema, + OperationDefinitionNode, + OperationType, +) from ..codegen import generate_import_from from ..exceptions import ParsingError @@ -13,9 +18,11 @@ from .client import ClientGenerator from .comments import get_comment from .constants import ( + BASE_GRAPHQL_OPERATION_CLASS_NAME, BASE_MODEL_CLASS_NAME, BASE_MODEL_FILE_PATH, BASE_MODEL_IMPORT, + BASE_OPERATION_FILE_PATH, DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH, DEFAULT_ASYNC_BASE_CLIENT_PATH, DEFAULT_BASE_CLIENT_OPEN_TELEMETRY_PATH, @@ -26,6 +33,9 @@ UPLOAD_CLASS_NAME, UPLOAD_IMPORT, ) +from .custom_fields import CustomFieldsGenerator +from .custom_fields_typing import CustomFieldsTypingGenerator +from .custom_operation import CustomOperationGenerator from .enums import EnumsGenerator from .fragments import FragmentsGenerator from .init_file import InitFileGenerator @@ -45,6 +55,10 @@ def __init__( enums_generator: EnumsGenerator, input_types_generator: InputTypesGenerator, fragments_generator: FragmentsGenerator, + custom_fields_generator: Optional[CustomFieldsGenerator] = None, + custom_fields_typing_generator: Optional[CustomFieldsTypingGenerator] = None, + custom_query_generator: Optional[CustomOperationGenerator] = None, + custom_mutation_generator: Optional[CustomOperationGenerator] = None, fragments_definitions: Optional[Dict[str, FragmentDefinitionNode]] = None, client_name: str = "Client", async_client: bool = True, @@ -54,6 +68,7 @@ def __init__( enums_module_name: str = "enums", input_types_module_name: str = "input_types", fragments_module_name: str = "fragments", + custom_help_field_module_name: str = "custom_typing_fields", comments_strategy: CommentsStrategy = CommentsStrategy.STABLE, queries_source: str = "", schema_source: str = "", @@ -61,12 +76,14 @@ def __init__( include_all_inputs: bool = True, include_all_enums: bool = True, base_model_file_path: str = BASE_MODEL_FILE_PATH.as_posix(), + base_schema_root_file_path: str = BASE_OPERATION_FILE_PATH.as_posix(), base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT, upload_import: ast.ImportFrom = UPLOAD_IMPORT, unset_import: ast.ImportFrom = UNSET_IMPORT, files_to_include: Optional[List[str]] = None, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, + enable_custom_operations: bool = False, ) -> None: self.package_path = Path(target_path) / package_name @@ -80,6 +97,11 @@ def __init__( self.enums_generator = enums_generator self.input_types_generator = input_types_generator self.fragments_generator = fragments_generator + self.custom_fields_generator = custom_fields_generator + self.custom_query_generator = custom_query_generator + self.custom_mutation_generator = custom_mutation_generator + self.custom_fields_typing_generator = custom_fields_typing_generator + self.custom_help_field_module_name = custom_help_field_module_name self.client_name = client_name self.async_client = async_client @@ -104,6 +126,8 @@ def __init__( self.upload_import = upload_import self.unset_import = unset_import + self.base_schema_root_file_path = Path(base_schema_root_file_path) + self.files_to_include = ( [Path(f) for f in files_to_include] if files_to_include else [] ) @@ -115,6 +139,10 @@ def __init__( self._unpacked_fragments: Set[str] = set() self._used_enums: List[str] = [] + self.enable_custom_operations = enable_custom_operations + if self.enable_custom_operations: + self.files_to_include.append(self.base_schema_root_file_path) + def generate(self) -> List[str]: """Generate package with graphql client.""" self._include_exceptions() @@ -125,6 +153,21 @@ def generate(self) -> List[str]: self._generate_result_types() self._generate_fragments() self._copy_files() + if self.enable_custom_operations: + self._generate_custom_fields_typing() + self._generate_custom_fields() + self.client_generator.add_execute_custom_operation_method() + if self.custom_query_generator: + self._generate_custom_queries() + self.client_generator.create_custom_operation_method( + "query", OperationType.QUERY.value.upper() + ) + if self.custom_mutation_generator: + self._generate_custom_mutations() + self.client_generator.create_custom_operation_method( + "mutation", OperationType.MUTATION.value.upper() + ) + self._generate_client() self._generate_enums() self._generate_init() @@ -335,6 +378,34 @@ def _generate_init(self): init_file_path.write_text(code) self._generated_files.append(init_file_path.name) + def _generate_custom_queries(self): + file_path = self.package_path / "custom_queries.py" + module = self.custom_query_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + + def _generate_custom_mutations(self): + file_path = self.package_path / "custom_mutations.py" + module = self.custom_mutation_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + + def _generate_custom_fields_typing(self): + file_path = self.package_path / "custom_typing_fields.py" + module = self.custom_fields_typing_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + + def _generate_custom_fields(self): + file_path = self.package_path / "custom_fields.py" + module = self.custom_fields_generator.generate() + code = self._add_comments_to_code(ast_to_str(module, False)) + file_path.write_text(code) + self._generated_files.append(file_path.name) + def get_package_generator( schema: GraphQLSchema, @@ -384,6 +455,42 @@ def get_package_generator( custom_scalars=settings.scalars, plugin_manager=plugin_manager, ) + custom_fields_generator = CustomFieldsGenerator( + schema=schema, custom_scalars=settings.scalars, plugin_manager=plugin_manager + ) + custom_fields_typing_generator = CustomFieldsTypingGenerator(schema=schema) + custom_query_generator = None + if schema.query_type: + custom_query_generator = CustomOperationGenerator( + graphql_fields=schema.query_type.fields, + name="Query", + base_name=BASE_GRAPHQL_OPERATION_CLASS_NAME, + enums_module_name=settings.enums_module_name, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + arguments_generator=ArgumentsGenerator( + schema=schema, + convert_to_snake_case=settings.convert_to_snake_case, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ), + ) + custom_mutation_generator = None + if schema.mutation_type: + custom_mutation_generator = CustomOperationGenerator( + graphql_fields=schema.mutation_type.fields, + name="Mutation", + base_name=BASE_GRAPHQL_OPERATION_CLASS_NAME, + enums_module_name=settings.enums_module_name, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + arguments_generator=ArgumentsGenerator( + schema=schema, + convert_to_snake_case=settings.convert_to_snake_case, + custom_scalars=settings.scalars, + plugin_manager=plugin_manager, + ), + ) return PackageGenerator( package_name=settings.target_package_name, @@ -403,6 +510,10 @@ def get_package_generator( enums_module_name=settings.enums_module_name, input_types_module_name=settings.input_types_module_name, fragments_module_name=settings.fragments_module_name, + custom_fields_generator=custom_fields_generator, + custom_fields_typing_generator=custom_fields_typing_generator, + custom_query_generator=custom_query_generator, + custom_mutation_generator=custom_mutation_generator, comments_strategy=settings.include_comments, queries_source=settings.queries_path, schema_source=settings.schema_source, @@ -416,4 +527,5 @@ def get_package_generator( files_to_include=settings.files_to_include, custom_scalars=settings.scalars, plugin_manager=plugin_manager, + enable_custom_operations=settings.enable_custom_operations, ) diff --git a/ariadne_codegen/codegen.py b/ariadne_codegen/codegen.py index 40a8b894..c4893511 100644 --- a/ariadne_codegen/codegen.py +++ b/ariadne_codegen/codegen.py @@ -24,8 +24,13 @@ from .exceptions import ParsingError +def generate_import(names: List[str], level: int = 0) -> ast.Import: + """Generate import statement.""" + return ast.Import(names=[ast.alias(n) for n in names], level=level) + + def generate_import_from( - names: List[str], from_: str, level: int = 0 + names: List[str], from_: Optional[str] = None, level: int = 0 ) -> ast.ImportFrom: """Generate import from statement.""" return ast.ImportFrom( @@ -34,7 +39,7 @@ def generate_import_from( def generate_nullable_annotation( - slice_: Union[ast.Name, ast.Subscript] + slice_: Union[ast.Name, ast.Subscript], ) -> ast.Subscript: """Generate optional annotation.""" return ast.Subscript(value=ast.Name(id=OPTIONAL), slice=slice_) @@ -65,15 +70,19 @@ def generate_arg( def generate_arguments( args: Optional[List[ast.arg]] = None, - defaults: Optional[List[ast.expr]] = None, + vararg: Optional[ast.arg] = None, + kwonlyargs: Optional[list[ast.arg]] = None, + kw_defaults: Optional[list[Union[ast.expr, None]]] = None, kwarg: Optional[ast.arg] = None, + defaults: Optional[List[ast.expr]] = None, ) -> ast.arguments: """Generate arguments.""" return ast.arguments( posonlyargs=[], args=args if args else [], - kwonlyargs=[], - kw_defaults=[], + vararg=vararg, + kwonlyargs=kwonlyargs if kwonlyargs else [], + kw_defaults=kw_defaults if kw_defaults else [], kwarg=kwarg, defaults=defaults or [], ) @@ -85,13 +94,14 @@ def generate_async_method_definition( return_type: Union[ast.Name, ast.Subscript], body: Optional[List[ast.stmt]] = None, lineno: int = 1, + decorator_list: Optional[List[ast.Name]] = None, ) -> ast.AsyncFunctionDef: """Generate async function.""" return ast.AsyncFunctionDef( name=name, args=arguments, body=body if body else [ast.Pass()], - decorator_list=[], + decorator_list=decorator_list if decorator_list else [], returns=return_type, lineno=lineno, ) @@ -118,11 +128,26 @@ def generate_name(name: str) -> ast.Name: return ast.Name(id=name) +def generate_joined_str(values: list[ast.expr]) -> ast.JoinedStr: + """Generate joined str object.""" + return ast.JoinedStr(values) + + def generate_constant(value: Any) -> ast.Constant: """Generate constant object.""" return ast.Constant(value=value) +def generate_formatted_value( + value: ast.expr, conversion: int = -1, format_spec: Optional[ast.expr] = None +) -> ast.FormattedValue: + return ast.FormattedValue( + value=value, + conversion=conversion, + format_spec=format_spec, + ) + + def generate_assign( targets: List[str], value: Union[ast.expr, List[ast.expr]], lineno: int = 1 ) -> ast.Assign: @@ -267,6 +292,25 @@ def generate_list(elements: List[Optional[ast.expr]]) -> ast.List: return ast.List(elts=elements) +def generate_list_comp( + elt: ast.expr, generators: list[ast.comprehension] +) -> ast.ListComp: + """Generate list comprehension""" + return ast.ListComp(elt=elt, generators=generators) + + +def generate_comp( + target: str, iter_: str, ifs: Optional[List[ast.expr]] = None, is_async: int = 0 +) -> ast.comprehension: + "Generate comprehension" + return ast.comprehension( + target=generate_name(target), + iter=generate_name(iter_), + ifs=ifs if ifs else [], + is_async=is_async, + ) + + def generate_lambda(body: ast.expr, args: Optional[ast.arguments] = None) -> ast.Lambda: """Generate lambda definition.""" return ast.Lambda(args=args or generate_arguments(), body=body) @@ -299,12 +343,13 @@ def generate_method_definition( return_type: Union[ast.Name, ast.Subscript], body: Optional[List[ast.stmt]] = None, lineno: int = 1, + decorator_list: Optional[List[ast.Name]] = None, ) -> ast.FunctionDef: return ast.FunctionDef( name=name, args=arguments, body=body if body else [ast.Pass()], - decorator_list=[], + decorator_list=decorator_list if decorator_list else [], returns=return_type, lineno=lineno, ) diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 57eb60c5..2dcd603c 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -61,9 +61,12 @@ def client(config_dict): schema = plugin_manager.process_schema(schema) assert_valid_schema(schema) - definitions = get_graphql_queries(settings.queries_path, schema) - queries = filter_operations_definitions(definitions) - fragments = filter_fragments_definitions(definitions) + fragments = [] + queries = [] + if settings.queries_path: + definitions = get_graphql_queries(settings.queries_path, schema) + queries = filter_operations_definitions(definitions) + fragments = filter_fragments_definitions(definitions) sys.stdout.write(settings.used_settings_message) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index fbd8f427..808397ba 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -37,6 +37,7 @@ class BaseSettings: remote_schema_url: str = "" remote_schema_headers: dict = field(default_factory=dict) remote_schema_verify_ssl: bool = True + enable_custom_operations: bool = False plugins: List[str] = field(default_factory=list) def __post_init__(self): @@ -73,7 +74,7 @@ class ClientSettings(BaseSettings): scalars: Dict[str, ScalarData] = field(default_factory=dict) def __post_init__(self): - if not self.queries_path: + if not self.queries_path and not self.enable_custom_operations: raise TypeError("__init__ missing 1 required argument: 'queries_path'") super().__post_init__() diff --git a/pyproject.toml b/pyproject.toml index 89f8bb40..de4589c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "ariadne-codegen" description = "Generate fully typed GraphQL client from schema, queries and mutations!" authors = [{ name = "Mirumee Software", email = "hello@mirumee.com" }] -version = "0.13.0" +version = "0.14.0" readme = "README.md" license = { file = "LICENSE" } classifiers = [ diff --git a/tests/main/clients/custom_query_builder/expected_client/__init__.py b/tests/main/clients/custom_query_builder/expected_client/__init__.py new file mode 100644 index 00000000..99fb757c --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/__init__.py @@ -0,0 +1,24 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .enums import MetadataErrorCode +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) + +__all__ = [ + "AsyncBaseClient", + "BaseModel", + "Client", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "MetadataErrorCode", + "Upload", +] diff --git a/tests/main/clients/custom_query_builder/expected_client/async_base_client.py b/tests/main/clients/custom_query_builder/expected_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/custom_query_builder/expected_client/base_model.py b/tests/main/clients/custom_query_builder/expected_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/custom_query_builder/expected_client/base_operation.py b/tests/main/clients/custom_query_builder/expected_client/base_operation.py new file mode 100644 index 00000000..9f9c7660 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/base_operation.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from graphql import ( + ArgumentNode, + FieldNode, + InlineFragmentNode, + NamedTypeNode, + NameNode, + SelectionSetNode, + VariableNode, +) + + +class GraphQLArgument: + """ + Represents a GraphQL argument and allows conversion to an AST structure. + """ + + def __init__(self, argument_name: str, argument_value: Any) -> None: + self._name = argument_name + self._value = argument_value + + def to_ast(self) -> ArgumentNode: + """Converts the argument to an ArgumentNode AST object.""" + return ArgumentNode( + name=NameNode(value=self._name), + value=VariableNode(name=NameNode(value=self._value)), + ) + + +class GraphQLField: + """ + Represents a GraphQL field with its name, arguments, subfields, alias, + and inline fragments. + + Attributes: + formatted_variables (Dict[str, Dict[str, Any]]): The formatted arguments + of the GraphQL field. + """ + + def __init__( + self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None + ) -> None: + self._field_name = field_name + self._variables = arguments or {} + self.formatted_variables: Dict[str, Dict[str, Any]] = {} + self._subfields: List[GraphQLField] = [] + self._alias: Optional[str] = None + self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} + + def alias(self, alias: str) -> "GraphQLField": + """Sets an alias for the GraphQL field and returns the instance.""" + self._alias = alias + return self + + def _build_field_name(self) -> str: + """Builds the field name, including the alias if present.""" + return f"{self._alias}: {self._field_name}" if self._alias else self._field_name + + def _build_selections( + self, idx: int, used_names: Set[str] + ) -> List[Union[FieldNode, InlineFragmentNode]]: + """Builds the selection set for the current GraphQL field, + including subfields and inline fragments.""" + # Create selections from subfields + selections: List[Union[FieldNode, InlineFragmentNode]] = [ + subfield.to_ast(idx, used_names) for subfield in self._subfields + ] + + # Add inline fragments + for name, subfields in self._inline_fragments.items(): + selections.append( + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[ + subfield.to_ast(idx, used_names) for subfield in subfields + ] + ), + ) + ) + + return selections + + def _format_variable_name( + self, idx: int, var_name: str, used_names: Set[str] + ) -> str: + """Generates a unique variable name by appending an index and, + if necessary, an additional counter to avoid duplicates.""" + base_name = f"{var_name}_{idx}" + unique_name = base_name + counter = 1 + + # Ensure the generated name is unique + while unique_name in used_names: + unique_name = f"{base_name}_{counter}" + counter += 1 + + # Add the unique name to the set of used names + used_names.add(unique_name) + + return unique_name + + def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + """ + Collects and formats all variables for the current GraphQL field, + ensuring unique names. + """ + self.formatted_variables = {} + + for k, v in self._variables.items(): + unique_name = self._format_variable_name(idx, k, used_names) + self.formatted_variables[unique_name] = { + "name": k, + "type": v["type"], + "value": v["value"], + } + + def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + """Converts the current GraphQL field to an AST (Abstract Syntax Tree) node.""" + if used_names is None: + used_names = set() + + self._collect_all_variables(idx, used_names) + + return FieldNode( + name=NameNode(value=self._build_field_name()), + arguments=[ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ], + selection_set=( + SelectionSetNode(selections=self._build_selections(idx, used_names)) + if self._subfields or self._inline_fragments + else None + ), + ) + + def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: + """ + Retrieves all formatted variables for the current GraphQL field, + including those from subfields and inline fragments. + """ + formatted_variables = self.formatted_variables.copy() + + # Collect variables from subfields + for subfield in self._subfields: + subfield.get_formatted_variables() + formatted_variables.update(subfield.formatted_variables) + + # Collect variables from inline fragments + for subfields in self._inline_fragments.values(): + for subfield in subfields: + subfield.get_formatted_variables() + formatted_variables.update(subfield.formatted_variables) + return formatted_variables diff --git a/tests/main/clients/custom_query_builder/expected_client/client.py b/tests/main/clients/custom_query_builder/expected_client/client.py new file mode 100644 index 00000000..c89fa2c6 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/client.py @@ -0,0 +1,107 @@ +from typing import Any, Dict, List, Tuple + +from graphql import ( + DocumentNode, + NamedTypeNode, + NameNode, + OperationDefinitionNode, + OperationType, + SelectionNode, + SelectionSetNode, + VariableDefinitionNode, + VariableNode, + print_ast, +) + +from .async_base_client import AsyncBaseClient +from .base_operation import GraphQLField + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def execute_custom_operation( + self, *fields: GraphQLField, operation_type: OperationType, operation_name: str + ) -> Dict[str, Any]: + selections = self._build_selection_set(fields) + combined_variables = self._combine_variables(fields) + variable_definitions = self._build_variable_definitions( + combined_variables["types"] + ) + operation_ast = self._build_operation_ast( + selections, operation_type, operation_name, variable_definitions + ) + response = await self.execute( + print_ast(operation_ast), + variables=combined_variables["values"], + operation_name=operation_name, + ) + return self.get_data(response) + + def _combine_variables( + self, fields: Tuple[GraphQLField, ...] + ) -> Dict[str, Dict[str, Any]]: + variables_types_combined = {} + processed_variables_combined = {} + for field in fields: + formatted_variables = field.get_formatted_variables() + variables_types_combined.update( + {k: v["type"] for k, v in formatted_variables.items()} + ) + processed_variables_combined.update( + {k: v["value"] for k, v in formatted_variables.items()} + ) + return { + "types": variables_types_combined, + "values": processed_variables_combined, + } + + def _build_variable_definitions( + self, variables_types_combined: Dict[str, str] + ) -> List[VariableDefinitionNode]: + return [ + VariableDefinitionNode( + variable=VariableNode(name=NameNode(value=var_name)), + type=NamedTypeNode(name=NameNode(value=var_value)), + ) + for var_name, var_value in variables_types_combined.items() + ] + + def _build_operation_ast( + self, + selections: List[SelectionNode], + operation_type: OperationType, + operation_name: str, + variable_definitions: List[VariableDefinitionNode], + ) -> DocumentNode: + return DocumentNode( + definitions=[ + OperationDefinitionNode( + operation=operation_type, + name=NameNode(value=operation_name), + variable_definitions=variable_definitions, + selection_set=SelectionSetNode(selections=selections), + ) + ] + ) + + def _build_selection_set( + self, fields: Tuple[GraphQLField, ...] + ) -> List[SelectionNode]: + return [field.to_ast(idx) for idx, field in enumerate(fields)] + + async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, operation_type=OperationType.QUERY, operation_name=operation_name + ) + + async def mutation( + self, *fields: GraphQLField, operation_name: str + ) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, + operation_type=OperationType.MUTATION, + operation_name=operation_name + ) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py new file mode 100644 index 00000000..799b8391 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_fields.py @@ -0,0 +1,384 @@ +from typing import Any, Dict, Optional, Union + +from .base_operation import GraphQLField +from .custom_typing_fields import ( + AppGraphQLField, + CollectionTranslatableContentGraphQLField, + MetadataErrorGraphQLField, + MetadataItemGraphQLField, + ObjectWithMetadataGraphQLField, + PageInfoGraphQLField, + ProductCountableConnectionGraphQLField, + ProductCountableEdgeGraphQLField, + ProductGraphQLField, + ProductTranslatableContentGraphQLField, + ProductTypeCountableConnectionGraphQLField, + TranslatableItemConnectionGraphQLField, + TranslatableItemEdgeGraphQLField, + TranslatableItemUnion, + UpdateMetadataGraphQLField, +) + + +class AppFields(GraphQLField): + id: "AppGraphQLField" = AppGraphQLField("id") + + def fields(self, *subfields: AppGraphQLField) -> "AppFields": + """Subfields should come from the AppFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "AppFields": + self._alias = alias + return self + + +class CollectionTranslatableContentFields(GraphQLField): + id: "CollectionTranslatableContentGraphQLField" = ( + CollectionTranslatableContentGraphQLField("id") + ) + collection_id: "CollectionTranslatableContentGraphQLField" = ( + CollectionTranslatableContentGraphQLField("collectionId") + ) + seo_title: "CollectionTranslatableContentGraphQLField" = ( + CollectionTranslatableContentGraphQLField("seoTitle") + ) + seo_description: "CollectionTranslatableContentGraphQLField" = ( + CollectionTranslatableContentGraphQLField("seoDescription") + ) + name: "CollectionTranslatableContentGraphQLField" = ( + CollectionTranslatableContentGraphQLField("name") + ) + description: "CollectionTranslatableContentGraphQLField" = ( + CollectionTranslatableContentGraphQLField("description") + ) + + def fields( + self, *subfields: CollectionTranslatableContentGraphQLField + ) -> "CollectionTranslatableContentFields": + """Subfields should come from the CollectionTranslatableContentFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "CollectionTranslatableContentFields": + self._alias = alias + return self + + +class MetadataErrorFields(GraphQLField): + field: "MetadataErrorGraphQLField" = MetadataErrorGraphQLField("field") + message: "MetadataErrorGraphQLField" = MetadataErrorGraphQLField("message") + code: "MetadataErrorGraphQLField" = MetadataErrorGraphQLField("code") + + def fields(self, *subfields: MetadataErrorGraphQLField) -> "MetadataErrorFields": + """Subfields should come from the MetadataErrorFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "MetadataErrorFields": + self._alias = alias + return self + + +class MetadataItemFields(GraphQLField): + key: "MetadataItemGraphQLField" = MetadataItemGraphQLField("key") + value: "MetadataItemGraphQLField" = MetadataItemGraphQLField("value") + + def fields(self, *subfields: MetadataItemGraphQLField) -> "MetadataItemFields": + """Subfields should come from the MetadataItemFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "MetadataItemFields": + self._alias = alias + return self + + +class ObjectWithMetadataInterface(GraphQLField): + @classmethod + def private_metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("private_metadata") + + @classmethod + def private_metafield(cls, key: str) -> "ObjectWithMetadataGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ObjectWithMetadataGraphQLField( + "private_metafield", arguments=cleared_arguments + ) + + @classmethod + def metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("metadata") + + @classmethod + def metafield(cls, key: str) -> "ObjectWithMetadataGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ObjectWithMetadataGraphQLField("metafield", arguments=cleared_arguments) + + def fields( + self, *subfields: Union[ObjectWithMetadataGraphQLField, "MetadataItemFields"] + ) -> "ObjectWithMetadataInterface": + """Subfields should come from the ObjectWithMetadataInterface class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "ObjectWithMetadataInterface": + self._alias = alias + return self + + def on( + self, type_name: str, *subfields: GraphQLField + ) -> "ObjectWithMetadataInterface": + self._inline_fragments[type_name] = subfields + return self + + +class PageInfoFields(GraphQLField): + has_next_page: "PageInfoGraphQLField" = PageInfoGraphQLField("hasNextPage") + has_previous_page: "PageInfoGraphQLField" = PageInfoGraphQLField("hasPreviousPage") + start_cursor: "PageInfoGraphQLField" = PageInfoGraphQLField("startCursor") + end_cursor: "PageInfoGraphQLField" = PageInfoGraphQLField("endCursor") + + def fields(self, *subfields: PageInfoGraphQLField) -> "PageInfoFields": + """Subfields should come from the PageInfoFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "PageInfoFields": + self._alias = alias + return self + + +class ProductFields(GraphQLField): + id: "ProductGraphQLField" = ProductGraphQLField("id") + slug: "ProductGraphQLField" = ProductGraphQLField("slug") + name: "ProductGraphQLField" = ProductGraphQLField("name") + + @classmethod + def private_metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("private_metadata") + + @classmethod + def private_metafield(cls, key: str) -> "ProductGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ProductGraphQLField("private_metafield", arguments=cleared_arguments) + + @classmethod + def metadata(cls) -> "MetadataItemFields": + return MetadataItemFields("metadata") + + @classmethod + def metafield(cls, key: str) -> "ProductGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ProductGraphQLField("metafield", arguments=cleared_arguments) + + def fields( + self, *subfields: Union[ProductGraphQLField, "MetadataItemFields"] + ) -> "ProductFields": + """Subfields should come from the ProductFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "ProductFields": + self._alias = alias + return self + + +class ProductCountableConnectionFields(GraphQLField): + @classmethod + def edges(cls) -> "ProductCountableEdgeFields": + return ProductCountableEdgeFields("edges") + + @classmethod + def page_info(cls) -> "PageInfoFields": + return PageInfoFields("page_info") + + total_count: "ProductCountableConnectionGraphQLField" = ( + ProductCountableConnectionGraphQLField("totalCount") + ) + + def fields( + self, + *subfields: Union[ + ProductCountableConnectionGraphQLField, + "PageInfoFields", + "ProductCountableEdgeFields", + ] + ) -> "ProductCountableConnectionFields": + """Subfields should come from the ProductCountableConnectionFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "ProductCountableConnectionFields": + self._alias = alias + return self + + +class ProductCountableEdgeFields(GraphQLField): + @classmethod + def node(cls) -> "ProductFields": + return ProductFields("node") + + cursor: "ProductCountableEdgeGraphQLField" = ProductCountableEdgeGraphQLField( + "cursor" + ) + + def fields( + self, *subfields: Union[ProductCountableEdgeGraphQLField, "ProductFields"] + ) -> "ProductCountableEdgeFields": + """Subfields should come from the ProductCountableEdgeFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "ProductCountableEdgeFields": + self._alias = alias + return self + + +class ProductTranslatableContentFields(GraphQLField): + id: "ProductTranslatableContentGraphQLField" = ( + ProductTranslatableContentGraphQLField("id") + ) + product_id: "ProductTranslatableContentGraphQLField" = ( + ProductTranslatableContentGraphQLField("productId") + ) + seo_title: "ProductTranslatableContentGraphQLField" = ( + ProductTranslatableContentGraphQLField("seoTitle") + ) + seo_description: "ProductTranslatableContentGraphQLField" = ( + ProductTranslatableContentGraphQLField("seoDescription") + ) + name: "ProductTranslatableContentGraphQLField" = ( + ProductTranslatableContentGraphQLField("name") + ) + description: "ProductTranslatableContentGraphQLField" = ( + ProductTranslatableContentGraphQLField("description") + ) + + def fields( + self, *subfields: ProductTranslatableContentGraphQLField + ) -> "ProductTranslatableContentFields": + """Subfields should come from the ProductTranslatableContentFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "ProductTranslatableContentFields": + self._alias = alias + return self + + +class ProductTypeCountableConnectionFields(GraphQLField): + @classmethod + def page_info(cls) -> "PageInfoFields": + return PageInfoFields("page_info") + + def fields( + self, + *subfields: Union[ProductTypeCountableConnectionGraphQLField, "PageInfoFields"] + ) -> "ProductTypeCountableConnectionFields": + """Subfields should come from the ProductTypeCountableConnectionFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "ProductTypeCountableConnectionFields": + self._alias = alias + return self + + +class TranslatableItemConnectionFields(GraphQLField): + @classmethod + def page_info(cls) -> "PageInfoFields": + return PageInfoFields("page_info") + + @classmethod + def edges(cls) -> "TranslatableItemEdgeFields": + return TranslatableItemEdgeFields("edges") + + total_count: "TranslatableItemConnectionGraphQLField" = ( + TranslatableItemConnectionGraphQLField("totalCount") + ) + + def fields( + self, + *subfields: Union[ + TranslatableItemConnectionGraphQLField, + "PageInfoFields", + "TranslatableItemEdgeFields", + ] + ) -> "TranslatableItemConnectionFields": + """Subfields should come from the TranslatableItemConnectionFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "TranslatableItemConnectionFields": + self._alias = alias + return self + + +class TranslatableItemEdgeFields(GraphQLField): + node: "TranslatableItemUnion" = TranslatableItemUnion("node") + cursor: "TranslatableItemEdgeGraphQLField" = TranslatableItemEdgeGraphQLField( + "cursor" + ) + + def fields( + self, + *subfields: Union[TranslatableItemEdgeGraphQLField, "TranslatableItemUnion"] + ) -> "TranslatableItemEdgeFields": + """Subfields should come from the TranslatableItemEdgeFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "TranslatableItemEdgeFields": + self._alias = alias + return self + + +class UpdateMetadataFields(GraphQLField): + @classmethod + def metadata_errors(cls) -> "MetadataErrorFields": + return MetadataErrorFields("metadata_errors") + + @classmethod + def errors(cls) -> "MetadataErrorFields": + return MetadataErrorFields("errors") + + @classmethod + def item(cls) -> "ObjectWithMetadataInterface": + return ObjectWithMetadataInterface("item") + + def fields( + self, + *subfields: Union[ + UpdateMetadataGraphQLField, + "MetadataErrorFields", + "ObjectWithMetadataInterface", + ] + ) -> "UpdateMetadataFields": + """Subfields should come from the UpdateMetadataFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "UpdateMetadataFields": + self._alias = alias + return self diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py new file mode 100644 index 00000000..f4836ddb --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_mutations.py @@ -0,0 +1,15 @@ +from typing import Any, Dict, Optional + +from .custom_fields import UpdateMetadataFields + + +class Mutation: + @classmethod + def update_metadata(cls, id: str) -> UpdateMetadataFields: + arguments: Dict[str, Dict[str, Any]] = {"id": {"type": "ID!", "value": id}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UpdateMetadataFields( + field_name="updateMetadata", arguments=cleared_arguments + ) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_queries.py b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py new file mode 100644 index 00000000..d14cf4db --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_queries.py @@ -0,0 +1,55 @@ +from typing import Any, Dict, Optional + +from .custom_fields import ( + AppFields, + ProductCountableConnectionFields, + ProductTypeCountableConnectionFields, + TranslatableItemConnectionFields, +) + + +class Query: + @classmethod + def products( + cls, *, channel: Optional[str] = None, first: Optional[int] = None + ) -> ProductCountableConnectionFields: + arguments: Dict[str, Dict[str, Any]] = { + "channel": {"type": "String", "value": channel}, + "first": {"type": "Int", "value": first}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return ProductCountableConnectionFields( + field_name="products", arguments=cleared_arguments + ) + + @classmethod + def app(cls) -> AppFields: + return AppFields(field_name="app") + + @classmethod + def product_types(cls) -> ProductTypeCountableConnectionFields: + return ProductTypeCountableConnectionFields(field_name="productTypes") + + @classmethod + def translations( + cls, + *, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None + ) -> TranslatableItemConnectionFields: + arguments: Dict[str, Dict[str, Any]] = { + "before": {"type": "String", "value": before}, + "after": {"type": "String", "value": after}, + "first": {"type": "Int", "value": first}, + "last": {"type": "Int", "value": last}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return TranslatableItemConnectionFields( + field_name="translations", arguments=cleared_arguments + ) diff --git a/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py b/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py new file mode 100644 index 00000000..91be62aa --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/custom_typing_fields.py @@ -0,0 +1,95 @@ +from .base_operation import GraphQLField + + +class ProductGraphQLField(GraphQLField): + def alias(self, alias: str) -> "ProductGraphQLField": + self._alias = alias + return self + + +class ProductCountableEdgeGraphQLField(GraphQLField): + def alias(self, alias: str) -> "ProductCountableEdgeGraphQLField": + self._alias = alias + return self + + +class ProductCountableConnectionGraphQLField(GraphQLField): + def alias(self, alias: str) -> "ProductCountableConnectionGraphQLField": + self._alias = alias + return self + + +class AppGraphQLField(GraphQLField): + def alias(self, alias: str) -> "AppGraphQLField": + self._alias = alias + return self + + +class ProductTypeCountableConnectionGraphQLField(GraphQLField): + def alias(self, alias: str) -> "ProductTypeCountableConnectionGraphQLField": + self._alias = alias + return self + + +class PageInfoGraphQLField(GraphQLField): + def alias(self, alias: str) -> "PageInfoGraphQLField": + self._alias = alias + return self + + +class ObjectWithMetadataGraphQLField(GraphQLField): + def alias(self, alias: str) -> "ObjectWithMetadataGraphQLField": + self._alias = alias + return self + + +class MetadataItemGraphQLField(GraphQLField): + def alias(self, alias: str) -> "MetadataItemGraphQLField": + self._alias = alias + return self + + +class UpdateMetadataGraphQLField(GraphQLField): + def alias(self, alias: str) -> "UpdateMetadataGraphQLField": + self._alias = alias + return self + + +class MetadataErrorGraphQLField(GraphQLField): + def alias(self, alias: str) -> "MetadataErrorGraphQLField": + self._alias = alias + return self + + +class TranslatableItemConnectionGraphQLField(GraphQLField): + def alias(self, alias: str) -> "TranslatableItemConnectionGraphQLField": + self._alias = alias + return self + + +class TranslatableItemEdgeGraphQLField(GraphQLField): + def alias(self, alias: str) -> "TranslatableItemEdgeGraphQLField": + self._alias = alias + return self + + +class TranslatableItemUnion(GraphQLField): + def on(self, type_name: str, *subfields: GraphQLField) -> "TranslatableItemUnion": + self._inline_fragments[type_name] = subfields + return self + + def alias(self, alias: str) -> "TranslatableItemUnion": + self._alias = alias + return self + + +class ProductTranslatableContentGraphQLField(GraphQLField): + def alias(self, alias: str) -> "ProductTranslatableContentGraphQLField": + self._alias = alias + return self + + +class CollectionTranslatableContentGraphQLField(GraphQLField): + def alias(self, alias: str) -> "CollectionTranslatableContentGraphQLField": + self._alias = alias + return self diff --git a/tests/main/clients/custom_query_builder/expected_client/enums.py b/tests/main/clients/custom_query_builder/expected_client/enums.py new file mode 100644 index 00000000..b6a853d5 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/enums.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class MetadataErrorCode(str, Enum): + GRAPHQL_ERROR = "GRAPHQL_ERROR" + INVALID = "INVALID" + NOT_FOUND = "NOT_FOUND" + REQUIRED = "REQUIRED" + NOT_UPDATED = "NOT_UPDATED" diff --git a/tests/main/clients/custom_query_builder/expected_client/exceptions.py b/tests/main/clients/custom_query_builder/expected_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/clients/custom_query_builder/expected_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/custom_query_builder/expected_client/input_types.py b/tests/main/clients/custom_query_builder/expected_client/input_types.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/custom_query_builder/pyproject.toml b/tests/main/clients/custom_query_builder/pyproject.toml new file mode 100644 index 00000000..3cbb511e --- /dev/null +++ b/tests/main/clients/custom_query_builder/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +include_comments = "none" +target_package_name = "example_client" +enable_custom_operations = true diff --git a/tests/main/clients/custom_query_builder/schema.graphql b/tests/main/clients/custom_query_builder/schema.graphql new file mode 100644 index 00000000..4e873d7c --- /dev/null +++ b/tests/main/clients/custom_query_builder/schema.graphql @@ -0,0 +1,249 @@ +schema { + query: Query + mutation: Mutation +} + +type Query { + products(channel: String, first: Int): ProductCountableConnection + app: App + productTypes: ProductTypeCountableConnection + translations( + """ + Return the elements in the list that come before the specified cursor. + """ + before: String + + """ + Return the elements in the list that come after the specified cursor. + """ + after: String + + """ + Retrieve the first n elements from the list. Note that the system only allows fetching a maximum of 100 objects in a single query. + """ + first: Int + + """ + Retrieve the last n elements from the list. Note that the system only allows fetching a maximum of 100 objects in a single query. + """ + last: Int + ): TranslatableItemConnection +} + +type Mutation { + updateMetadata( + """ + ID or token (for Order and Checkout) of an object to update. + """ + id: ID! + ): UpdateMetadata +} + +type Product implements ObjectWithMetadata { + id: ID! + slug: String! + name: String! +} + +type ProductCountableEdge { + node: Product! + cursor: String! +} + +type ProductCountableConnection { + edges: [ProductCountableEdge!]! + pageInfo: PageInfo! + totalCount: Int +} + +type App { + id: ID! +} + +type ProductTypeCountableConnection { + pageInfo: PageInfo! +} + +type PageInfo { + hasNextPage: Boolean! + hasPreviousPage: Boolean! + startCursor: String + endCursor: String +} + +interface ObjectWithMetadata { + """ + List of private metadata items. Requires staff permissions to access. + """ + privateMetadata: [MetadataItem!]! + + """ + A single key from private metadata. Requires staff permissions to access. + + Tip: Use GraphQL aliases to fetch multiple keys. + """ + privateMetafield(key: String!): String + + """ + List of public metadata items. Can be accessed without permissions. + """ + metadata: [MetadataItem!]! + + """ + A single key from public metadata. + + Tip: Use GraphQL aliases to fetch multiple keys. + """ + metafield(key: String!): String +} + +type MetadataItem { + """ + Key of a metadata item. + """ + key: String! + + """ + Value of a metadata item. + """ + value: String! +} + +type UpdateMetadata { + metadataErrors: [MetadataError!]! + @deprecated( + reason: "This field will be removed in Saleor 4.0. Use `errors` field instead." + ) + errors: [MetadataError!]! + item: ObjectWithMetadata +} +type MetadataError { + """ + Name of a field that caused the error. A value of `null` indicates that the error isn't associated with a particular field. + """ + field: String + + """ + The error message. + """ + message: String + + """ + The error code. + """ + code: MetadataErrorCode! +} + +""" +An enumeration. +""" +enum MetadataErrorCode { + GRAPHQL_ERROR + INVALID + NOT_FOUND + REQUIRED + NOT_UPDATED +} + +type TranslatableItemConnection { + """ + Pagination data for this connection. + """ + pageInfo: PageInfo! + edges: [TranslatableItemEdge!]! + + """ + A total count of items in the collection. + """ + totalCount: Int +} + +type TranslatableItemEdge { + """ + The item at the end of the edge. + """ + node: TranslatableItem! + + """ + A cursor for use in pagination. + """ + cursor: String! +} + +union TranslatableItem = + ProductTranslatableContent + | CollectionTranslatableContent + +type ProductTranslatableContent @doc(category: "Products") { + """ + The ID of the product translatable content. + """ + id: ID! + + """ + The ID of the product to translate. + + Added in Saleor 3.14. + """ + productId: ID! + + """ + SEO title to translate. + """ + seoTitle: String + + """ + SEO description to translate. + """ + seoDescription: String + + """ + Product's name to translate. + """ + name: String! + + """ + Product's description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString +} + +type CollectionTranslatableContent @doc(category: "Products") { + """ + The ID of the collection translatable content. + """ + id: ID! + + """ + The ID of the collection to translate. + + Added in Saleor 3.14. + """ + collectionId: ID! + + """ + SEO title to translate. + """ + seoTitle: String + + """ + SEO description to translate. + """ + seoDescription: String + + """ + Collection's name to translate. + """ + name: String! + + """ + Collection's description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString +} + +scalar JSONString diff --git a/tests/main/custom_operation_builder/__init__.py b/tests/main/custom_operation_builder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/custom_operation_builder/graphql_client/__init__.py b/tests/main/custom_operation_builder/graphql_client/__init__.py new file mode 100644 index 00000000..de6c91ca --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/__init__.py @@ -0,0 +1,27 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .enums import Role +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) +from .input_types import AddUserInput, UpdateUserInput + +__all__ = [ + "AddUserInput", + "AsyncBaseClient", + "BaseModel", + "Client", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "Role", + "UpdateUserInput", + "Upload", +] diff --git a/tests/main/custom_operation_builder/graphql_client/async_base_client.py b/tests/main/custom_operation_builder/graphql_client/async_base_client.py new file mode 100644 index 00000000..5358ced6 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/custom_operation_builder/graphql_client/base_model.py b/tests/main/custom_operation_builder/graphql_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/custom_operation_builder/graphql_client/base_operation.py b/tests/main/custom_operation_builder/graphql_client/base_operation.py new file mode 100644 index 00000000..9f9c7660 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/base_operation.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from graphql import ( + ArgumentNode, + FieldNode, + InlineFragmentNode, + NamedTypeNode, + NameNode, + SelectionSetNode, + VariableNode, +) + + +class GraphQLArgument: + """ + Represents a GraphQL argument and allows conversion to an AST structure. + """ + + def __init__(self, argument_name: str, argument_value: Any) -> None: + self._name = argument_name + self._value = argument_value + + def to_ast(self) -> ArgumentNode: + """Converts the argument to an ArgumentNode AST object.""" + return ArgumentNode( + name=NameNode(value=self._name), + value=VariableNode(name=NameNode(value=self._value)), + ) + + +class GraphQLField: + """ + Represents a GraphQL field with its name, arguments, subfields, alias, + and inline fragments. + + Attributes: + formatted_variables (Dict[str, Dict[str, Any]]): The formatted arguments + of the GraphQL field. + """ + + def __init__( + self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None + ) -> None: + self._field_name = field_name + self._variables = arguments or {} + self.formatted_variables: Dict[str, Dict[str, Any]] = {} + self._subfields: List[GraphQLField] = [] + self._alias: Optional[str] = None + self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {} + + def alias(self, alias: str) -> "GraphQLField": + """Sets an alias for the GraphQL field and returns the instance.""" + self._alias = alias + return self + + def _build_field_name(self) -> str: + """Builds the field name, including the alias if present.""" + return f"{self._alias}: {self._field_name}" if self._alias else self._field_name + + def _build_selections( + self, idx: int, used_names: Set[str] + ) -> List[Union[FieldNode, InlineFragmentNode]]: + """Builds the selection set for the current GraphQL field, + including subfields and inline fragments.""" + # Create selections from subfields + selections: List[Union[FieldNode, InlineFragmentNode]] = [ + subfield.to_ast(idx, used_names) for subfield in self._subfields + ] + + # Add inline fragments + for name, subfields in self._inline_fragments.items(): + selections.append( + InlineFragmentNode( + type_condition=NamedTypeNode(name=NameNode(value=name)), + selection_set=SelectionSetNode( + selections=[ + subfield.to_ast(idx, used_names) for subfield in subfields + ] + ), + ) + ) + + return selections + + def _format_variable_name( + self, idx: int, var_name: str, used_names: Set[str] + ) -> str: + """Generates a unique variable name by appending an index and, + if necessary, an additional counter to avoid duplicates.""" + base_name = f"{var_name}_{idx}" + unique_name = base_name + counter = 1 + + # Ensure the generated name is unique + while unique_name in used_names: + unique_name = f"{base_name}_{counter}" + counter += 1 + + # Add the unique name to the set of used names + used_names.add(unique_name) + + return unique_name + + def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None: + """ + Collects and formats all variables for the current GraphQL field, + ensuring unique names. + """ + self.formatted_variables = {} + + for k, v in self._variables.items(): + unique_name = self._format_variable_name(idx, k, used_names) + self.formatted_variables[unique_name] = { + "name": k, + "type": v["type"], + "value": v["value"], + } + + def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode: + """Converts the current GraphQL field to an AST (Abstract Syntax Tree) node.""" + if used_names is None: + used_names = set() + + self._collect_all_variables(idx, used_names) + + return FieldNode( + name=NameNode(value=self._build_field_name()), + arguments=[ + GraphQLArgument(v["name"], k).to_ast() + for k, v in self.formatted_variables.items() + ], + selection_set=( + SelectionSetNode(selections=self._build_selections(idx, used_names)) + if self._subfields or self._inline_fragments + else None + ), + ) + + def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]: + """ + Retrieves all formatted variables for the current GraphQL field, + including those from subfields and inline fragments. + """ + formatted_variables = self.formatted_variables.copy() + + # Collect variables from subfields + for subfield in self._subfields: + subfield.get_formatted_variables() + formatted_variables.update(subfield.formatted_variables) + + # Collect variables from inline fragments + for subfields in self._inline_fragments.values(): + for subfield in subfields: + subfield.get_formatted_variables() + formatted_variables.update(subfield.formatted_variables) + return formatted_variables diff --git a/tests/main/custom_operation_builder/graphql_client/client.py b/tests/main/custom_operation_builder/graphql_client/client.py new file mode 100644 index 00000000..c89fa2c6 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/client.py @@ -0,0 +1,107 @@ +from typing import Any, Dict, List, Tuple + +from graphql import ( + DocumentNode, + NamedTypeNode, + NameNode, + OperationDefinitionNode, + OperationType, + SelectionNode, + SelectionSetNode, + VariableDefinitionNode, + VariableNode, + print_ast, +) + +from .async_base_client import AsyncBaseClient +from .base_operation import GraphQLField + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def execute_custom_operation( + self, *fields: GraphQLField, operation_type: OperationType, operation_name: str + ) -> Dict[str, Any]: + selections = self._build_selection_set(fields) + combined_variables = self._combine_variables(fields) + variable_definitions = self._build_variable_definitions( + combined_variables["types"] + ) + operation_ast = self._build_operation_ast( + selections, operation_type, operation_name, variable_definitions + ) + response = await self.execute( + print_ast(operation_ast), + variables=combined_variables["values"], + operation_name=operation_name, + ) + return self.get_data(response) + + def _combine_variables( + self, fields: Tuple[GraphQLField, ...] + ) -> Dict[str, Dict[str, Any]]: + variables_types_combined = {} + processed_variables_combined = {} + for field in fields: + formatted_variables = field.get_formatted_variables() + variables_types_combined.update( + {k: v["type"] for k, v in formatted_variables.items()} + ) + processed_variables_combined.update( + {k: v["value"] for k, v in formatted_variables.items()} + ) + return { + "types": variables_types_combined, + "values": processed_variables_combined, + } + + def _build_variable_definitions( + self, variables_types_combined: Dict[str, str] + ) -> List[VariableDefinitionNode]: + return [ + VariableDefinitionNode( + variable=VariableNode(name=NameNode(value=var_name)), + type=NamedTypeNode(name=NameNode(value=var_value)), + ) + for var_name, var_value in variables_types_combined.items() + ] + + def _build_operation_ast( + self, + selections: List[SelectionNode], + operation_type: OperationType, + operation_name: str, + variable_definitions: List[VariableDefinitionNode], + ) -> DocumentNode: + return DocumentNode( + definitions=[ + OperationDefinitionNode( + operation=operation_type, + name=NameNode(value=operation_name), + variable_definitions=variable_definitions, + selection_set=SelectionSetNode(selections=selections), + ) + ] + ) + + def _build_selection_set( + self, fields: Tuple[GraphQLField, ...] + ) -> List[SelectionNode]: + return [field.to_ast(idx) for idx, field in enumerate(fields)] + + async def query(self, *fields: GraphQLField, operation_name: str) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, operation_type=OperationType.QUERY, operation_name=operation_name + ) + + async def mutation( + self, *fields: GraphQLField, operation_name: str + ) -> Dict[str, Any]: + return await self.execute_custom_operation( + *fields, + operation_type=OperationType.MUTATION, + operation_name=operation_name + ) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_fields.py new file mode 100644 index 00000000..d5f29221 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_fields.py @@ -0,0 +1,160 @@ +from typing import Any, Dict, Optional, Union + +from .base_operation import GraphQLField +from .custom_typing_fields import ( + AdminGraphQLField, + GuestGraphQLField, + PersonInterfaceGraphQLField, + PostGraphQLField, + UserGraphQLField, +) + + +class AdminFields(GraphQLField): + id: "AdminGraphQLField" = AdminGraphQLField("id") + name: "AdminGraphQLField" = AdminGraphQLField("name") + privileges: "AdminGraphQLField" = AdminGraphQLField("privileges") + email: "AdminGraphQLField" = AdminGraphQLField("email") + created_at: "AdminGraphQLField" = AdminGraphQLField("createdAt") + + @classmethod + def metafield(cls, key: str) -> "AdminGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return AdminGraphQLField("metafield", arguments=cleared_arguments) + + @classmethod + def custom_field(cls, *, key: Optional[str] = None) -> "AdminGraphQLField": + arguments: Dict[str, Dict[str, Any]] = {"key": {"type": "String", "value": key}} + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return AdminGraphQLField("custom_field", arguments=cleared_arguments) + + def fields(self, *subfields: AdminGraphQLField) -> "AdminFields": + """Subfields should come from the AdminFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "AdminFields": + self._alias = alias + return self + + +class GuestFields(GraphQLField): + id: "GuestGraphQLField" = GuestGraphQLField("id") + name: "GuestGraphQLField" = GuestGraphQLField("name") + visit_count: "GuestGraphQLField" = GuestGraphQLField("visitCount") + email: "GuestGraphQLField" = GuestGraphQLField("email") + created_at: "GuestGraphQLField" = GuestGraphQLField("createdAt") + + @classmethod + def metafield(cls, key: str) -> "GuestGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return GuestGraphQLField("metafield", arguments=cleared_arguments) + + def fields(self, *subfields: GuestGraphQLField) -> "GuestFields": + """Subfields should come from the GuestFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "GuestFields": + self._alias = alias + return self + + +class PersonInterfaceInterface(GraphQLField): + id: "PersonInterfaceGraphQLField" = PersonInterfaceGraphQLField("id") + name: "PersonInterfaceGraphQLField" = PersonInterfaceGraphQLField("name") + email: "PersonInterfaceGraphQLField" = PersonInterfaceGraphQLField("email") + + @classmethod + def metafield(cls, key: str) -> "PersonInterfaceGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PersonInterfaceGraphQLField("metafield", arguments=cleared_arguments) + + def fields( + self, *subfields: PersonInterfaceGraphQLField + ) -> "PersonInterfaceInterface": + """Subfields should come from the PersonInterfaceInterface class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "PersonInterfaceInterface": + self._alias = alias + return self + + def on( + self, type_name: str, *subfields: GraphQLField + ) -> "PersonInterfaceInterface": + self._inline_fragments[type_name] = subfields + return self + + +class PostFields(GraphQLField): + id: "PostGraphQLField" = PostGraphQLField("id") + title: "PostGraphQLField" = PostGraphQLField("title") + content: "PostGraphQLField" = PostGraphQLField("content") + + @classmethod + def author(cls) -> "PersonInterfaceInterface": + return PersonInterfaceInterface("author") + + published_at: "PostGraphQLField" = PostGraphQLField("publishedAt") + + def fields( + self, *subfields: Union[PostGraphQLField, "PersonInterfaceInterface"] + ) -> "PostFields": + """Subfields should come from the PostFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "PostFields": + self._alias = alias + return self + + +class UserFields(GraphQLField): + id: "UserGraphQLField" = UserGraphQLField("id") + name: "UserGraphQLField" = UserGraphQLField("name") + age: "UserGraphQLField" = UserGraphQLField("age") + email: "UserGraphQLField" = UserGraphQLField("email") + role: "UserGraphQLField" = UserGraphQLField("role") + created_at: "UserGraphQLField" = UserGraphQLField("createdAt") + + @classmethod + def friends(cls) -> "UserFields": + return UserFields("friends") + + @classmethod + def metafield(cls, key: str) -> "UserGraphQLField": + arguments: Dict[str, Dict[str, Any]] = { + "key": {"type": "String!", "value": key} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserGraphQLField("metafield", arguments=cleared_arguments) + + def fields(self, *subfields: Union[UserGraphQLField, "UserFields"]) -> "UserFields": + """Subfields should come from the UserFields class""" + self._subfields.extend(subfields) + return self + + def alias(self, alias: str) -> "UserFields": + self._alias = alias + return self diff --git a/tests/main/custom_operation_builder/graphql_client/custom_mutations.py b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py new file mode 100644 index 00000000..6b271253 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_mutations.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, Optional + +from .custom_fields import PostFields, UserFields +from .input_types import AddUserInput, UpdateUserInput + + +class Mutation: + @classmethod + def add_user(cls, user_input: AddUserInput) -> UserFields: + arguments: Dict[str, Dict[str, Any]] = { + "user_input": {"type": "AddUserInput!", "value": user_input} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="addUser", arguments=cleared_arguments) + + @classmethod + def update_user(cls, user_id: str, user_input: UpdateUserInput) -> UserFields: + arguments: Dict[str, Dict[str, Any]] = { + "user_id": {"type": "ID!", "value": user_id}, + "user_input": {"type": "UpdateUserInput!", "value": user_input}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="updateUser", arguments=cleared_arguments) + + @classmethod + def delete_user(cls, user_id: str) -> UserFields: + arguments: Dict[str, Dict[str, Any]] = { + "user_id": {"type": "ID!", "value": user_id} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="deleteUser", arguments=cleared_arguments) + + @classmethod + def add_post( + cls, title: str, content: str, author_id: str, published_at: str + ) -> PostFields: + arguments: Dict[str, Dict[str, Any]] = { + "title": {"type": "String!", "value": title}, + "content": {"type": "String!", "value": content}, + "authorId": {"type": "ID!", "value": author_id}, + "publishedAt": {"type": "String!", "value": published_at}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PostFields(field_name="addPost", arguments=cleared_arguments) + + @classmethod + def update_post( + cls, + post_id: str, + *, + title: Optional[str] = None, + content: Optional[str] = None, + published_at: Optional[str] = None + ) -> PostFields: + arguments: Dict[str, Dict[str, Any]] = { + "post_id": {"type": "ID!", "value": post_id}, + "title": {"type": "String", "value": title}, + "content": {"type": "String", "value": content}, + "publishedAt": {"type": "String", "value": published_at}, + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PostFields(field_name="updatePost", arguments=cleared_arguments) + + @classmethod + def delete_post(cls, post_id: str) -> PostFields: + arguments: Dict[str, Dict[str, Any]] = { + "post_id": {"type": "ID!", "value": post_id} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PostFields(field_name="deletePost", arguments=cleared_arguments) diff --git a/tests/main/custom_operation_builder/graphql_client/custom_queries.py b/tests/main/custom_operation_builder/graphql_client/custom_queries.py new file mode 100644 index 00000000..353ed882 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_queries.py @@ -0,0 +1,64 @@ +from typing import Any, Dict, Optional + +from .custom_fields import PersonInterfaceInterface, PostFields, UserFields +from .custom_typing_fields import GraphQLField, SearchResultUnion + + +class Query: + @classmethod + def hello(cls) -> GraphQLField: + return GraphQLField(field_name="hello") + + @classmethod + def greeting(cls, *, name: Optional[str] = None) -> GraphQLField: + arguments: Dict[str, Dict[str, Any]] = { + "name": {"type": "String", "value": name} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return GraphQLField(field_name="greeting", arguments=cleared_arguments) + + @classmethod + def user(cls, user_id: str) -> UserFields: + arguments: Dict[str, Dict[str, Any]] = { + "user_id": {"type": "ID!", "value": user_id} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return UserFields(field_name="user", arguments=cleared_arguments) + + @classmethod + def users(cls) -> UserFields: + return UserFields(field_name="users") + + @classmethod + def search(cls, text: str) -> SearchResultUnion: + arguments: Dict[str, Dict[str, Any]] = { + "text": {"type": "String!", "value": text} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return SearchResultUnion(field_name="search", arguments=cleared_arguments) + + @classmethod + def posts(cls) -> PostFields: + return PostFields(field_name="posts") + + @classmethod + def person(cls, person_id: str) -> PersonInterfaceInterface: + arguments: Dict[str, Dict[str, Any]] = { + "person_id": {"type": "ID!", "value": person_id} + } + cleared_arguments = { + key: value for key, value in arguments.items() if value["value"] is not None + } + return PersonInterfaceInterface( + field_name="person", arguments=cleared_arguments + ) + + @classmethod + def people(cls) -> PersonInterfaceInterface: + return PersonInterfaceInterface(field_name="people") diff --git a/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py new file mode 100644 index 00000000..b4dd2471 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/custom_typing_fields.py @@ -0,0 +1,41 @@ +from .base_operation import GraphQLField + + +class PersonInterfaceGraphQLField(GraphQLField): + def alias(self, alias: str) -> "PersonInterfaceGraphQLField": + self._alias = alias + return self + + +class UserGraphQLField(GraphQLField): + def alias(self, alias: str) -> "UserGraphQLField": + self._alias = alias + return self + + +class AdminGraphQLField(GraphQLField): + def alias(self, alias: str) -> "AdminGraphQLField": + self._alias = alias + return self + + +class GuestGraphQLField(GraphQLField): + def alias(self, alias: str) -> "GuestGraphQLField": + self._alias = alias + return self + + +class PostGraphQLField(GraphQLField): + def alias(self, alias: str) -> "PostGraphQLField": + self._alias = alias + return self + + +class SearchResultUnion(GraphQLField): + def on(self, type_name: str, *subfields: GraphQLField) -> "SearchResultUnion": + self._inline_fragments[type_name] = subfields + return self + + def alias(self, alias: str) -> "SearchResultUnion": + self._alias = alias + return self diff --git a/tests/main/custom_operation_builder/graphql_client/enums.py b/tests/main/custom_operation_builder/graphql_client/enums.py new file mode 100644 index 00000000..72d8ca4b --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/enums.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class Role(str, Enum): + ADMIN = "ADMIN" + USER = "USER" diff --git a/tests/main/custom_operation_builder/graphql_client/exceptions.py b/tests/main/custom_operation_builder/graphql_client/exceptions.py new file mode 100644 index 00000000..b34acfe1 --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/custom_operation_builder/graphql_client/input_types.py b/tests/main/custom_operation_builder/graphql_client/input_types.py new file mode 100644 index 00000000..7c78cfaf --- /dev/null +++ b/tests/main/custom_operation_builder/graphql_client/input_types.py @@ -0,0 +1,22 @@ +from typing import Optional + +from pydantic import Field + +from .base_model import BaseModel +from .enums import Role + + +class AddUserInput(BaseModel): + name: str + age: int + email: str + role: Role + created_at: str = Field(alias="createdAt") + + +class UpdateUserInput(BaseModel): + name: Optional[str] = None + age: Optional[int] = None + email: Optional[str] = None + role: Optional[Role] = None + created_at: Optional[str] = Field(alias="createdAt", default=None) diff --git a/tests/main/custom_operation_builder/schema.graphql b/tests/main/custom_operation_builder/schema.graphql new file mode 100644 index 00000000..74be4fd9 --- /dev/null +++ b/tests/main/custom_operation_builder/schema.graphql @@ -0,0 +1,102 @@ +enum Role { + ADMIN + USER +} + +input AddUserInput { + name: String! + age: Int! + email: String! + role: Role! + createdAt: String! +} + +input UpdateUserInput { + name: String + age: Int + email: String + role: Role + createdAt: String +} + +interface PersonInterface { + id: ID! + name: String! + email: String! + metafield(key: String!): String +} + +type User implements PersonInterface { + id: ID! + name: String! + age: Int + email: String! + role: Role! + createdAt: String + friends: [User] + metafield(key: String!): String +} + +type Admin implements PersonInterface { + id: ID! + name: String! + privileges: [String!]! + email: String! + createdAt: String + metafield(key: String!): String + customField(key: String): String +} + +type Guest implements PersonInterface { + id: ID! + name: String! + visitCount: Int + email: String! + createdAt: String + metafield(key: String!): String +} + +type Post { + id: ID! + title: String! + content: String! + author: PersonInterface + publishedAt: String +} + +type Query { + hello: String + greeting(name: String): String + user(user_id: ID!): User + users: [User] + search(text: String!): [SearchResult] + posts: [Post] + person(person_id: ID!): PersonInterface + people: [PersonInterface] +} + +union SearchResult = User | Admin | Guest + +type Mutation { + addUser(user_input: AddUserInput!): User + updateUser(user_id: ID!, user_input: UpdateUserInput!): User + deleteUser(user_id: ID!): User + addPost( + title: String! + content: String! + authorId: ID! + publishedAt: String! + ): Post + updatePost( + post_id: ID! + title: String + content: String + publishedAt: String + ): Post + deletePost(post_id: ID!): Post +} + +schema { + query: Query + mutation: Mutation +} diff --git a/tests/main/custom_operation_builder/test_operation_build.py b/tests/main/custom_operation_builder/test_operation_build.py new file mode 100644 index 00000000..19e4c83f --- /dev/null +++ b/tests/main/custom_operation_builder/test_operation_build.py @@ -0,0 +1,565 @@ +from graphql import print_ast + +from .graphql_client.custom_fields import ( + AdminFields, + GuestFields, + PersonInterfaceInterface, + PostFields, + UserFields, +) +from .graphql_client.custom_mutations import Mutation +from .graphql_client.custom_queries import Query +from .graphql_client.enums import Role +from .graphql_client.input_types import AddUserInput, UpdateUserInput + + +def test_simple_hello(): + built_query = print_ast(Query.hello().to_ast(0)) + expected_query = "hello" + assert built_query == expected_query + + +def test_alias(): + built_query = print_ast(Query.hello().alias("aliased_hello").to_ast(0)) + expected_query = "aliased_hello: hello" + assert built_query == expected_query + + +def test_greeting_with_name(): + query = Query.greeting(name="Alice") + expected_query = "greeting(name: $name_0)" + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "name_0": {"name": "name", "type": "String", "value": "Alice"} + } + + +def test_user_by_id(): + query = Query.user(user_id="1").fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + ) + expected_query = "user(user_id: $user_id_0) {\n id\n name\n age\n email\n}" + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } + + +def test_all_users(): + query = Query.users().fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + ) + expected_query = "users {\n id\n name\n age\n email\n}" + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert not query.get_formatted_variables() + + +def test_user_with_friends(): + query = Query.user(user_id="1").fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.friends().fields( + UserFields.id, + UserFields.name, + ), + UserFields.created_at, + ) + expected_query = ( + "user(user_id: $user_id_0) {\n" + " id\n" + " name\n" + " age\n" + " email\n" + " friends {\n" + " id\n" + " name\n" + " }\n" + " createdAt\n" + "}" + ) + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } + + +def test_search_example(): + query = ( + Query.search(text="example") + .on( + "User", + UserFields.id, + UserFields.name, + UserFields.email, + UserFields.created_at, + ) + .on( + "Admin", + AdminFields.id, + AdminFields.name, + AdminFields.privileges, + AdminFields.created_at, + ) + .on( + "Guest", + GuestFields.id, + GuestFields.name, + GuestFields.visit_count, + GuestFields.created_at, + ) + ) + expected_query = ( + "search(text: $text_0) {\n" + " ... on User {\n" + " id\n" + " name\n" + " email\n" + " createdAt\n" + " }\n" + " ... on Admin {\n" + " id\n" + " name\n" + " privileges\n" + " createdAt\n" + " }\n" + " ... on Guest {\n" + " id\n" + " name\n" + " visitCount\n" + " createdAt\n" + " }\n" + "}" + ) + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "text_0": {"name": "text", "type": "String!", "value": "example"} + } + + +def test_posts_with_authors(): + query = Query.posts().fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.author().fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + ), + PostFields.published_at, + ) + expected_query = ( + "posts {\n" + " id\n" + " title\n" + " content\n" + " author {\n" + " id\n" + " name\n" + " email\n" + " }\n" + " publishedAt\n" + "}" + ) + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert not query.get_formatted_variables() + + +def test_get_person(): + query = ( + Query.person(person_id="1") + .fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + ) + .on( + "User", + UserFields.age, + UserFields.role, + UserFields.metafield(key="meta"), + ) + .on( + "Admin", + AdminFields.privileges, + AdminFields.custom_field(), + AdminFields.metafield(key="meta"), + ) + ) + expected_query = ( + "person(person_id: $person_id_0) {\n" + " id\n" + " name\n" + " email\n" + " ... on User {\n" + " age\n" + " role\n" + " metafield(key: $key_0)\n" + " }\n" + " ... on Admin {\n" + " privileges\n" + " custom_field\n" + " metafield(key: $key_0_1)\n" + " }\n" + "}" + ) + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "person_id_0": {"name": "person_id", "type": "ID!", "value": "1"}, + "key_0": {"name": "key", "type": "String!", "value": "meta"}, + "key_0_1": {"name": "key", "type": "String!", "value": "meta"}, + } + + +def test_get_people(): + query = ( + Query.people() + .fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + ) + .on("User", UserFields.age, UserFields.role) + .on("Admin", AdminFields.privileges) + ) + expected_query = ( + "people {\n" + " id\n" + " name\n" + " email\n" + " ... on User {\n" + " age\n" + " role\n" + " }\n" + " ... on Admin {\n" + " privileges\n" + " }\n" + "}" + ) + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert not query.get_formatted_variables() + + +def test_add_user_mutation(): + user_input = AddUserInput( + name="bob", + age=30, + email="bob@example.com", + role=Role.ADMIN, + createdAt="2024-06-07T00:00:00.000Z", + ) + mutation = Mutation.add_user(user_input=user_input).fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.role, + UserFields.created_at, + ) + expected_mutation = ( + "addUser(" + "user_input: $user_input_0" + ") {\n" + " id\n" + " name\n" + " age\n" + " email\n" + " role\n" + " createdAt\n" + "}" + ) + + built_mutation = print_ast(mutation.to_ast(0)) + + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "user_input_0": { + "name": "user_input", + "type": "AddUserInput!", + "value": user_input, + } + } + + +def test_update_user_mutation(): + user_input = UpdateUserInput( + name="Alice Updated", + age=25, + email="alice.updated@example.com", + role=Role.USER, + createdAt="2024-06-07T00:00:00.000Z", + ) + mutation = Mutation.update_user( + user_id="1", + user_input=user_input, + ).fields( + UserFields.id, + UserFields.name, + UserFields.age, + UserFields.email, + UserFields.role, + UserFields.created_at, + ) + expected_mutation = ( + "updateUser(" + "user_id: $user_id_0, " + "user_input: $user_input_0" + ") {\n" + " id\n" + " name\n" + " age\n" + " email\n" + " role\n" + " createdAt\n" + "}" + ) + + built_mutation = print_ast(mutation.to_ast(0)) + + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"}, + "user_input_0": { + "name": "user_input", + "type": "UpdateUserInput!", + "value": user_input, + }, + } + + +def test_delete_user_mutation(): + mutation = Mutation.delete_user(user_id="1").fields( + UserFields.id, + UserFields.name, + ) + expected_mutation = "deleteUser(user_id: $user_id_0) {\n id\n name\n}" + + built_mutation = print_ast(mutation.to_ast(0)) + + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } + + +def test_add_post_mutation(): + mutation = Mutation.add_post( + title="New Post", + content="This is the content", + author_id="1", + published_at="2024-06-07T00:00:00.000Z", + ).fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.author().fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + ), + PostFields.published_at, + ) + expected_mutation = ( + "addPost(\n" + " title: $title_0\n" + " content: $content_0\n" + " authorId: $authorId_0\n" + " publishedAt: $publishedAt_0\n" + ") {\n" + " id\n" + " title\n" + " content\n" + " author {\n" + " id\n" + " name\n" + " }\n" + " publishedAt\n" + "}" + ) + + built_mutation = print_ast(mutation.to_ast(0)) + + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "title_0": { + "name": "title", + "type": "String!", + "value": "New Post", + }, + "content_0": { + "name": "content", + "type": "String!", + "value": "This is the content", + }, + "authorId_0": { + "name": "authorId", + "type": "ID!", + "value": "1", + }, + "publishedAt_0": { + "name": "publishedAt", + "type": "String!", + "value": "2024-06-07T00:00:00.000Z", + }, + } + + +def test_update_post_mutation(): + mutation = Mutation.update_post( + post_id="1", + title="Updated Title", + content="Updated Content", + published_at="2024-06-07T00:00:00.000Z", + ).fields( + PostFields.id, + PostFields.title, + PostFields.content, + PostFields.published_at, + ) + expected_mutation = ( + "updatePost(\n" + " post_id: $post_id_0\n" + " title: $title_0\n" + " content: $content_0\n" + " publishedAt: $publishedAt_0\n" + ") {\n" + " id\n" + " title\n" + " content\n" + " publishedAt\n" + "}" + ) + + built_mutation = print_ast(mutation.to_ast(0)) + + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "post_id_0": {"name": "post_id", "type": "ID!", "value": "1"}, + "title_0": {"name": "title", "type": "String", "value": "Updated Title"}, + "content_0": {"name": "content", "type": "String", "value": "Updated Content"}, + "publishedAt_0": { + "name": "publishedAt", + "type": "String", + "value": "2024-06-07T00:00:00.000Z", + }, + } + + +def test_delete_post_mutation(): + mutation = Mutation.delete_post(post_id="1").fields( + PostFields.id, + PostFields.title, + ) + expected_mutation = "deletePost(post_id: $post_id_0) {\n id\n title\n}" + + built_mutation = print_ast(mutation.to_ast(0)) + + assert built_mutation == expected_mutation + assert mutation.get_formatted_variables() == { + "post_id_0": {"name": "post_id", "type": "ID!", "value": "1"} + } + + +def test_user_specific_fields(): + query = Query.user(user_id="1").fields(UserFields.id, UserFields.name) + expected_query = "user(user_id: $user_id_0) {\n id\n name\n}" + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } + + +def test_user_with_friends_specific_fields(): + query = Query.user(user_id="1").fields( + UserFields.id, + UserFields.name, + UserFields.friends().fields(UserFields.id, UserFields.name), + UserFields.created_at, + ) + expected_query = ( + "user(user_id: $user_id_0) {\n" + " id\n" + " name\n" + " friends {\n" + " id\n" + " name\n" + " }\n" + " createdAt\n" + "}" + ) + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "user_id_0": {"name": "user_id", "type": "ID!", "value": "1"} + } + + +def test_people_with_metadata(): + query = ( + Query.people() + .fields( + PersonInterfaceInterface.id, + PersonInterfaceInterface.name, + PersonInterfaceInterface.email, + PersonInterfaceInterface.metafield(key="bio"), + PersonInterfaceInterface.metafield(key="ots"), + ) + .on("User", UserFields.age, UserFields.role) + ) + expected_query = ( + "people {\n" + " id\n" + " name\n" + " email\n" + " metafield(key: $key_0)\n" + " metafield(key: $key_0_1)\n" + " ... on User {\n" + " age\n" + " role\n" + " }\n" + "}" + ) + + built_query = print_ast(query.to_ast(0)) + + assert built_query == expected_query + assert query.get_formatted_variables() == { + "key_0": {"name": "key", "type": "String!", "value": "bio"}, + "key_0_1": {"name": "key", "type": "String!", "value": "ots"}, + } diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 9b94f1a7..2182a1de 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -197,6 +197,14 @@ def test_main_shows_version(): "interface_as_fragment", CLIENTS_PATH / "interface_as_fragment" / "expected_client", ), + ( + ( + CLIENTS_PATH / "custom_query_builder" / "pyproject.toml", + (CLIENTS_PATH / "custom_query_builder" / "schema.graphql",), + ), + "example_client", + CLIENTS_PATH / "custom_query_builder" / "expected_client", + ), ], indirect=["project_dir"], )