Skip to content

Commit

Permalink
Merge pull request #283 from bombsimon/fix/input-type-model-rebuild
Browse files Browse the repository at this point in the history
fix: Include `model_rebuild` for input types as well
  • Loading branch information
rafalp authored Mar 6, 2024
2 parents ba664b5 + d1f89c1 commit c14dd92
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 24 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

## 0.14.0 (Unreleased)

- Re-added `model_rebuild` calls for input types with forward references.


## 0.13.0 (2024-03-4)

- Fixed `str_to_snake_case` utility to capture fully capitalized words followed by an underscore.
Expand Down
16 changes: 14 additions & 2 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
generate_ann_assign,
generate_class_def,
generate_constant,
generate_expr,
generate_import_from,
generate_keyword,
generate_method_call,
generate_module,
generate_pydantic_field,
model_has_forward_refs,
)
from ..plugins.manager import PluginManager
from ..utils import process_name
Expand All @@ -28,6 +31,7 @@
BASE_MODEL_IMPORT,
FIELD_CLASS,
LIST,
MODEL_REBUILD_METHOD,
OPTIONAL,
PLAIN_SERIALIZER,
PYDANTIC_MODULE,
Expand Down Expand Up @@ -85,8 +89,16 @@ def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module:
scalar_data = self.custom_scalars[scalar_name]
self._imports.extend(generate_scalar_imports(scalar_data))

module_body = cast(List[ast.stmt], self._imports) + cast(
List[ast.stmt], class_defs
model_rebuild_calls = [
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
for class_def in class_defs
if model_has_forward_refs(class_def)
]

module_body = (
cast(List[ast.stmt], self._imports)
+ cast(List[ast.stmt], class_defs)
+ cast(List[ast.stmt], model_rebuild_calls)
)
module = generate_module(body=module_body)

Expand Down
24 changes: 2 additions & 22 deletions ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
generate_module,
generate_pass,
generate_pydantic_field,
model_has_forward_refs,
)
from ..exceptions import NotSupported, ParsingError
from ..plugins.manager import PluginManager
Expand Down Expand Up @@ -158,7 +159,7 @@ def generate(self) -> ast.Module:
model_rebuild_calls = [
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
for class_def in self._class_defs
if self.include_model_rebuild(class_def)
if model_has_forward_refs(class_def)
]

module_body = (
Expand All @@ -174,11 +175,6 @@ def generate(self) -> ast.Module:
)
return module

def include_model_rebuild(self, class_def: ast.ClassDef) -> bool:
visitor = ClassDefNamesVisitor()
visitor.visit(class_def)
return visitor.found_name_with_quote

def get_imports(self) -> List[ast.ImportFrom]:
return self._imports

Expand Down Expand Up @@ -576,19 +572,3 @@ def enter_field(node: FieldNode, *_args: Any) -> FieldNode:
copied_node = deepcopy(node)
visit(copied_node, RemoveMixinVisitor())
return copied_node


class ClassDefNamesVisitor(ast.NodeVisitor):
def __init__(self):
self.found_name_with_quote = False

def visit_Name(self, node): # pylint: disable=C0103
if '"' in node.id:
self.found_name_with_quote = True
self.generic_visit(node)

def visit_Subscript(self, node): # pylint: disable=C0103
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
return

self.generic_visit(node)
22 changes: 22 additions & 0 deletions ariadne_codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,25 @@ def generate_yield(value: Optional[ast.expr] = None) -> ast.Yield:

def generate_pass() -> ast.Pass:
return ast.Pass()


def model_has_forward_refs(class_def: ast.ClassDef) -> bool:
visitor = ClassDefNamesVisitor()
visitor.visit(class_def)
return visitor.found_name_with_quote


class ClassDefNamesVisitor(ast.NodeVisitor):
def __init__(self):
self.found_name_with_quote = False

def visit_Name(self, node): # pylint: disable=C0103
if '"' in node.id:
self.found_name_with_quote = True
self.generic_visit(node)

def visit_Subscript(self, node): # pylint: disable=C0103
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
return

self.generic_visit(node)
4 changes: 4 additions & 0 deletions tests/main/clients/example/expected_client/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,7 @@ class NotificationsPreferencesInput(BaseModel):
receive_push_notifications: bool = Field(alias="receivePushNotifications")
receive_sms: bool = Field(alias="receiveSms")
title: str


UserCreateInput.model_rebuild()
UserPreferencesInput.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ class InputAB(BaseModel):

class InputE(BaseModel):
val: EnumE


InputA.model_rebuild()
InputAA.model_rebuild()
InputAB.model_rebuild()

0 comments on commit c14dd92

Please sign in to comment.