Skip to content

Commit

Permalink
fix: handle @include , @Skip directives when checking user queried fi…
Browse files Browse the repository at this point in the history
…elds
  • Loading branch information
mak626 committed Dec 19, 2023
1 parent 8ef1805 commit b31b0b9
Showing 1 changed file with 57 additions and 13 deletions.
70 changes: 57 additions & 13 deletions graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
from typing import Any, Callable, Union

import mongoengine
from asgiref.sync import sync_to_async as asgiref_sync_to_async
from asgiref.sync import SyncToAsync
from asgiref.sync import sync_to_async as asgiref_sync_to_async
from graphene import Node
from graphene.utils.trim_docstring import trim_docstring
from graphql import FieldNode
from graphql import (
BooleanValueNode,
FieldNode,
GraphQLIncludeDirective,
GraphQLSkipDirective,
VariableNode,
)
from graphql_relay.connection.array_connection import offset_to_cursor


Expand Down Expand Up @@ -112,12 +118,44 @@ def get_node_from_global_id(node, info, global_id):
return Node.get_node_from_global_id(info, global_id)


def collect_query_fields(node, fragments):
def include_field_by_directives(node, variables):
"""
Evaluates the graphql directives to determine if the queried field is to be included
Handles Directives
@skip
@include
"""
directives = node.get("directives") if isinstance(node, dict) else node.directives
if not directives:
return True

directive_results = []
for directive in directives:
argument_results = []
for argument in directive.arguments:
if isinstance(argument.value, BooleanValueNode):
argument_results.append(argument.value.value)
elif isinstance(argument.value, VariableNode):
argument_results.append(variables.get(argument.value.name.value))

directive_name = directive.name.value
if directive_name == GraphQLIncludeDirective.name:
directive_results.append(True if any(argument_results) else False)
elif directive_name == GraphQLSkipDirective.name:
directive_results.append(False if all(argument_results) else True)

return all(directive_results) if len(directive_results) > 0 else True


def collect_query_fields(node, fragments, variables):
"""Recursively collects fields from the AST
Args:
node (dict): A node in the AST
fragments (dict): Fragment definitions
variables (dict): User defined variables & values
Returns:
A dict mapping each field found, along with their sub fields.
Expand All @@ -133,20 +171,23 @@ def collect_query_fields(node, fragments):
"""

field = {}
selection_set = None
if isinstance(node, dict):
selection_set = node.get("selection_set")
else:
selection_set = node.selection_set
selection_set = node.get("selection_set") if isinstance(node, dict) else node.selection_set
if selection_set:
for leaf in selection_set.selections:
if leaf.kind == "field":
field.update({leaf.name.value: collect_query_fields(leaf, fragments)})
if include_field_by_directives(leaf, variables):
field.update(
{leaf.name.value: collect_query_fields(leaf, fragments, variables)}
)
elif leaf.kind == "fragment_spread":
field.update(collect_query_fields(fragments[leaf.name.value], fragments))
field.update(collect_query_fields(fragments[leaf.name.value], fragments, variables))
elif leaf.kind == "inline_fragment":
field.update(
{leaf.type_condition.name.value: collect_query_fields(leaf, fragments)}
{
leaf.type_condition.name.value: collect_query_fields(
leaf, fragments, variables
)
}
)

return field
Expand All @@ -164,11 +205,12 @@ def get_query_fields(info):

fragments = {}
node = ast_to_dict(info.field_nodes[0])
variables = info.variable_values

for name, value in info.fragments.items():
fragments[name] = ast_to_dict(value)

query = collect_query_fields(node, fragments)
query = collect_query_fields(node, fragments, variables)
if "edges" in query:
return query["edges"]["node"].keys()
return query
Expand All @@ -189,10 +231,12 @@ def has_page_info(info):
if not info:
return True # Returning True if invalid info is provided
node = ast_to_dict(info.field_nodes[0])
variables = info.variable_values

for name, value in info.fragments.items():
fragments[name] = ast_to_dict(value)

query = collect_query_fields(node, fragments)
query = collect_query_fields(node, fragments, variables)
return next((True for x in query.keys() if x.lower() == "pageinfo"), False)


Expand Down

0 comments on commit b31b0b9

Please sign in to comment.