Skip to content

Commit

Permalink
Add support to interface and union type
Browse files Browse the repository at this point in the history
  • Loading branch information
DamianCzajkowski committed Jun 6, 2024
1 parent 52e9aa6 commit a6019da
Show file tree
Hide file tree
Showing 18 changed files with 34,548 additions and 239 deletions.
68 changes: 46 additions & 22 deletions ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def __init__(

self._imports: List[Union[ast.ImportFrom, ast.Import]] = []
self._add_import(
generate_import_from([OPTIONAL, LIST, DICT, ANY, UNION, ASYNC_ITERATOR], TYPING_MODULE)
generate_import_from(
[OPTIONAL, LIST, DICT, ANY, UNION, ASYNC_ITERATOR], TYPING_MODULE
)
)
self._add_import(base_client_import)
self._add_import(unset_import)
Expand Down Expand Up @@ -151,17 +153,19 @@ def add_method(
operation_name = definition.name.value if definition.name else ""
if definition.operation == OperationType.SUBSCRIPTION:
if not async_:
raise NotSupported("Subscriptions are only available when using async client.")
method_def: Union[
ast.FunctionDef, ast.AsyncFunctionDef
] = self._generate_subscription_method_def(
name=name,
operation_name=operation_name,
return_type=return_type,
arguments=arguments,
arguments_dict=arguments_dict,
operation_str=operation_str,
variable_names=variable_names,
raise NotSupported(
"Subscriptions are only available when using async client."
)
method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef] = (
self._generate_subscription_method_def(
name=name,
operation_name=operation_name,
return_type=return_type,
arguments=arguments,
arguments_dict=arguments_dict,
operation_str=operation_str,
variable_names=variable_names,
)
)
elif async_:
method_def = self._generate_async_method(
Expand Down Expand Up @@ -223,21 +227,27 @@ def add_execute_custom_operation_method(self):
)
],
keywords=[
generate_keyword(arg="operation_name", value=generate_name("operation_name"))
generate_keyword(
arg="operation_name", value=generate_name("operation_name")
)
],
)
)

operation_definition_node = generate_call(
func=generate_name("OperationDefinitionNode"),
keywords=[
generate_keyword(arg="operation", value=generate_name("operation_type")),
generate_keyword(
arg="operation", value=generate_name("operation_type")
),
generate_keyword(
arg="name",
value=generate_call(
func=generate_name("NameNode"),
keywords=[
generate_keyword(arg="value", value=generate_name("operation_name"))
generate_keyword(
arg="value", value=generate_name("operation_name")
)
],
),
),
Expand Down Expand Up @@ -379,7 +389,9 @@ def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]:
argument_names = set(arg.arg for arg in arguments.args)

for variable in mapped_variable_names:
variable_names[variable] = f"_{variable}" if variable in argument_names else variable
variable_names[variable] = (
f"_{variable}" if variable in argument_names else variable
)

return variable_names

Expand Down Expand Up @@ -469,7 +481,9 @@ def _generate_operation_str_assign(
targets=[variable_names[self._operation_str_variable]],
value=generate_call(
func=generate_name(self._gql_func_name),
args=[[generate_constant(l + "\n") for l in operation_str.splitlines()]],
args=[
[generate_constant(l + "\n") for l in operation_str.splitlines()]
],
),
lineno=lineno,
)
Expand All @@ -492,7 +506,9 @@ def _generate_async_response_assign(
) -> ast.Assign:
return generate_assign(
targets=[variable_names[self._response_variable]],
value=generate_await(self._generate_execute_call(variable_names, operation_name)),
value=generate_await(
self._generate_execute_call(variable_names, operation_name)
),
lineno=lineno,
)

Expand All @@ -518,7 +534,9 @@ def _generate_execute_call(
value=generate_name(variable_names[self._operation_str_variable]),
arg="query",
),
generate_keyword(value=generate_constant(operation_name), arg="operation_name"),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(variable_names[self._variables_dict_variable]),
arg="variables",
Expand All @@ -541,7 +559,9 @@ def _generate_return_parsed_obj(
) -> ast.Return:
return generate_return(
generate_call(
func=generate_attribute(generate_name(return_type), MODEL_VALIDATE_METHOD),
func=generate_attribute(
generate_name(return_type), MODEL_VALIDATE_METHOD
),
args=[generate_name(variable_names[self._data_variable])],
)
)
Expand All @@ -559,14 +579,18 @@ def _generate_async_generator_loop(
func=generate_attribute(value=generate_name("self"), attr="execute_ws"),
keywords=[
generate_keyword(
value=generate_name(variable_names[self._operation_str_variable]),
value=generate_name(
variable_names[self._operation_str_variable]
),
arg="query",
),
generate_keyword(
value=generate_constant(operation_name), arg="operation_name"
),
generate_keyword(
value=generate_name(variable_names[self._variables_dict_variable]),
value=generate_name(
variable_names[self._variables_dict_variable]
),
arg="variables",
),
generate_keyword(value=generate_name(KWARGS_NAMES)),
Expand Down
4 changes: 3 additions & 1 deletion ariadne_codegen/client_generators/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@

KWARGS_NAMES = "kwargs"

DEFAULT_ASYNC_BASE_CLIENT_PATH = Path(__file__).parent / "dependencies" / "async_base_client.py"
DEFAULT_ASYNC_BASE_CLIENT_PATH = (
Path(__file__).parent / "dependencies" / "async_base_client.py"
)
DEFAULT_ASYNC_BASE_CLIENT_NAME = "AsyncBaseClient"

DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_PATH = (
Expand Down
47 changes: 33 additions & 14 deletions ariadne_codegen/client_generators/custom_fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import ast
from typing import Dict, List, Optional, Set, Union, cast

from graphql import GraphQLObjectType, GraphQLSchema
from graphql import (
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLSchema,
GraphQLUnionType,
)

from ..codegen import (
generate_ann_assign,
Expand All @@ -13,7 +18,6 @@
generate_constant,
generate_expr,
generate_import_from,
generate_keyword,
generate_method_definition,
generate_module,
generate_name,
Expand Down Expand Up @@ -70,23 +74,31 @@ def generate(self) -> ast.Module:

def _parse_object_type_definitions(self, class_definitions):
class_defs = []
interface_defs = []
for type_name in class_definitions:
graphql_type = self.schema.get_type(type_name)
if isinstance(graphql_type, GraphQLObjectType):
class_def = self._parse_graphql_types_definition(
graphql_type, "GraphQLField"
class_def = self._generate_class_def_body(
definition=graphql_type,
class_name=f"{graphql_type.name}Fields",
)
class_defs.append(class_def)
return class_defs

def _parse_graphql_types_definition(
self, definition: GraphQLObjectType, base_name
return [*interface_defs, *class_defs]

def _generate_class_def_body(
self,
definition: Union[GraphQLObjectType, GraphQLInterfaceType],
class_name: str,
) -> ast.ClassDef:
class_name = f"{definition.name}Fields"
class_def = generate_class_def(name=class_name, base_names=[base_name])
base_names = ["GraphQLField"]
additional_fields_typing = set()
definition_fields: Dict[str, ast.ClassDef] = dict(definition.fields.items())
for interface in definition.interfaces:
definition_fields.update(dict(interface.fields.items()))
class_def = generate_class_def(name=class_name, base_names=base_names)

for lineno, (org_name, field) in enumerate(definition.fields.items(), start=1):
for lineno, (org_name, field) in enumerate(definition_fields.items(), start=1):
name = process_name(
org_name,
convert_to_snake_case=self.convert_to_snake_case,
Expand All @@ -98,10 +110,16 @@ def _parse_graphql_types_definition(
)
additional_fields_typing.add(f"{final_type.name}Fields")
else:
self._add_import(
generate_import_from([f"{definition.name}GraphQLField"], level=1)
)
field_class_name = generate_name(f"{definition.name}GraphQLField")
field_name = f"{definition.name}GraphQLField"
if isinstance(final_type, GraphQLInterfaceType):
field_name = f"{final_type.name}Interface"
additional_fields_typing.add(field_name)
if isinstance(final_type, GraphQLUnionType):
field_name = f"{final_type.name}Union"
additional_fields_typing.add(field_name)
self._add_import(generate_import_from([field_name], level=1))
field_class_name = generate_name(field_name)

field_implementation = generate_ann_assign(
target=name,
annotation=field_class_name,
Expand All @@ -119,6 +137,7 @@ def _parse_graphql_types_definition(
class_name, definition.name, additional_fields_typing
)
)

return class_def

def _generate_fields_method(
Expand Down
69 changes: 64 additions & 5 deletions ariadne_codegen/client_generators/custom_fields_typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
import ast
from typing import List, cast
from typing import List, Union, cast

from graphql import (
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLSchema,
GraphQLUnionType,
)

from ariadne_codegen.client_generators.utils import get_final_type
from graphql import GraphQLObjectType, GraphQLSchema

from ..codegen import generate_class_def, generate_module
from ..codegen import (
generate_arg,
generate_arguments,
generate_attribute,
generate_class_def,
generate_method_definition,
generate_module,
generate_name,
generate_return,
generate_subscript,
)
from .constants import BASE_OPERATION_FILE_PATH, OPERATION_TYPES


Expand Down Expand Up @@ -36,21 +52,64 @@ def _filter_types(self):
return [
get_final_type(definition)
for name, definition in self.schema.type_map.items()
if isinstance(definition, GraphQLObjectType)
if isinstance(
definition, (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType)
)
and not name.startswith("__")
and name not in OPERATION_TYPES
]

def _generate_field_class(self, class_def: ast.ClassDef) -> ast.ClassDef:
class_name = f"{class_def.name}GraphQLField"
class_body: List[ast.stmt] = []
if isinstance(class_def, GraphQLInterfaceType):
class_name = f"{class_def.name}Interface"
class_body.append(self._generate_on_method(class_name))
if isinstance(class_def, GraphQLUnionType):
class_name = f"{class_def.name}Union"
class_body.append(self._generate_on_method(class_name))
if class_name not in self._public_names:
self._public_names.append(class_name)
field_class_def = generate_class_def(
name=class_name,
base_names=["GraphQLField"],
body=[ast.Pass()],
body=class_body if class_body else cast(List[ast.stmt], [ast.Pass()]),
)
return field_class_def

def _generate_on_method(self, class_name: str) -> ast.FunctionDef:
return generate_method_definition(
"on",
arguments=generate_arguments(
[
generate_arg(name="self"),
generate_arg(name="type_name", annotation=generate_name("str")),
generate_arg(
name="*subfields", annotation=generate_name("GraphQLField")
),
]
),
body=[
cast(
ast.stmt,
ast.Assign(
targets=[
generate_subscript(
value=generate_attribute(
value=generate_name("self"),
attr="_inline_fragments",
),
slice_=generate_name("type_name"),
)
],
value=generate_name("subfields"),
lineno=1,
),
),
generate_return(value=generate_name("self")),
],
return_type=generate_name(f'"{class_name}"'),
)

def get_generated_public_names(self) -> List[str]:
return self._public_names
Loading

0 comments on commit a6019da

Please sign in to comment.