Skip to content

Commit

Permalink
Split _update_imports
Browse files Browse the repository at this point in the history
  • Loading branch information
bombsimon committed Apr 2, 2024
1 parent 5791619 commit 4e06845
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions ariadne_codegen/contrib/client_forward_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,18 +251,13 @@ def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]:

return call.func.value

def _update_imports(self, module: ast.Module):
def _update_imports(self, module: ast.Module) -> None:
"""Update all imports.
Iterate over all imports and remove the aliases that we use as input or
return value. These will be moved and added to an `if TYPE_CHECKING`
block.
**NOTE** If an `ast.ImportFrom` ends up without any names we must remove
it completely otherwise formatting will not work (it would remove the
empty `import from` but not format the rest of the code without running
it twice).
We do this by storing all imports that we want to keep in an array, we
then drop all from the body and re-insert the ones to keep. Lastly we
import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING`
Expand All @@ -286,12 +281,29 @@ def _update_imports(self, module: ast.Module):
if len(return_types_not_used_as_input) == 0:
return None

# We sadly have to iterate over all imports again and remove the imports
# we will do conditionally.
# It's very important that we get this right, if we keep any
# `ImportFrom` that ends up without any names, the formatting will not
# work! It will only remove the empty `import from` but not other unused
# imports.
non_empty_imports = self._update_existing_imports(
module, return_types_not_used_as_input
)
self._add_forward_ref_imports(module, non_empty_imports)

return None

def _update_existing_imports(
self, module: ast.Module, return_types_not_used_as_input: set[str]
) -> List[Union[ast.Import, ast.ImportFrom]]:
"""Update existing imports.
Remove all import or import from statements that would otherwise be
useless after moving them to forward refs.
It's very important that we get this right, if we keep any `ImportFrom`
that ends up without any names, the formatting will not work! It will
only remove the empty `import from` but not other unused imports.
:param module: The ast module to update
:param return_types_not_used_as_input: Set of return types not used as
input
"""
non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = []
last_import_at = 0
for i, node in enumerate(module.body):
Expand All @@ -316,8 +328,18 @@ def _update_imports(self, module: ast.Module):
# We can now remove all imports and re-insert the ones that's not empty.
module.body = non_empty_imports + module.body[last_import_at + 1 :]

# Create import to use for type checking. These will be put in an `if
# TYPE_CHECKING` block.
return non_empty_imports

def _add_forward_ref_imports(
self,
module: ast.Module,
non_empty_imports: List[Union[ast.Import, ast.ImportFrom]],
) -> None:
"""Add forward ref imports.
Add all the forward ref imports meaning all the types needed for type
checking under the `if TYPE_CHECKING` condition.
"""
type_checking_imports = {}
for cls in self.input_and_return_types:
module_name = self.imported_classes[cls]
Expand Down Expand Up @@ -345,8 +367,6 @@ def _update_imports(self, module: ast.Module):
),
)

return None

def _update_name_to_constant(self, node: ast.expr) -> ast.expr:
"""Update return types.
Expand Down

0 comments on commit 4e06845

Please sign in to comment.