add limit check
trim21 committed Sep 22, 2024
1 parent 3f991f3 commit c4daad3
Showing 2 changed files with 133 additions and 9 deletions.
22 changes: 13 additions & 9 deletions gql/app.py
@@ -10,21 +10,18 @@
 from chii.compat import phpseralize
 from chii.const import CollectionType
 from chii.db import sa
-from chii.db.tables import (
-    ChiiTimeline,
-    ChiiTimeline_column_cat,
-    ChiiTimeline_column_id,
-)
+from chii.db.tables import ChiiTimeline, ChiiTimeline_column_cat, ChiiTimeline_column_id
 from chii.timeline import TimelineCat
 from gql.model import CollectTimeline
+from gql.rules import depth_limit_validator

 # Define types using Schema Definition Language (https://graphql.org/learn/schema/)
 # Wrapping string in gql function provides validation and better error traceback
 type_defs = gql(
     Path(__file__, "..", "schema.graphql").resolve().read_text(encoding="utf8")
 )

-CessionMaker = sa.async_session_maker()
+CreateSession = sa.async_session_maker()

 # Map resolver functions to Query fields using QueryType
 query = QueryType()
@@ -33,8 +30,8 @@
 # Resolvers are simple python functions
 @query.field("timeline_collection")
 async def timeline_collection(*_: Any) -> list[CollectTimeline]:
-    async with CessionMaker() as session:
-        rows: list[ChiiTimeline] = await session.execute(
+    async with CreateSession() as session:
+        rows: list[tuple[ChiiTimeline,]] = await session.execute(
             select(ChiiTimeline)
             .where(ChiiTimeline_column_cat == TimelineCat.Subject)
             .order_by(ChiiTimeline_column_id.desc())
@@ -77,6 +74,13 @@ async def timeline_collection(*_: Any) -> list[CollectTimeline]:
 app = Starlette(
     debug=True,
     routes=[
-        Mount("/graphql", GraphQL(schema, debug=True)),
+        Mount(
+            "/graphql",
+            GraphQL(
+                schema,
+                debug=True,
+                validation_rules=[depth_limit_validator(max_depth=5)],
+            ),
+        ),
     ],
 )
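
For reference, the validation_rules entry above feeds graphql-core's standard validation step, so over-deep operations are rejected before any resolver runs. A minimal sketch of what max_depth=5 refuses, run against a stand-in schema (the node/child/value types below are hypothetical; the project's real schema.graphql is not part of this diff):

from graphql import build_schema, parse, validate
from graphql.validation.validate import ValidationAbortedError

from gql.rules import depth_limit_validator

# Hypothetical demo schema, only for exercising the rule.
demo_schema = build_schema(
    """
    type Query { node: Node }
    type Node { child: Node value: Int }
    """
)

# Nested seven field levels deep, which is deeper than the configured limit of 5.
deep_query = parse("{ node { child { child { child { child { child { value } } } } } } }")

try:
    errors = validate(demo_schema, deep_query, rules=[depth_limit_validator(max_depth=5)])
except ValidationAbortedError:
    # Depending on the graphql-core version, the abort raised inside the rule's
    # __init__ may escape validate() instead of being folded into its return value.
    print("rejected: exceeds maximum operation depth of 5")
else:
    print(errors)  # expected: a GraphQLError reporting that the operation exceeds depth 5
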
120 changes: 120 additions & 0 deletions gql/rules.py
@@ -0,0 +1,120 @@
+from typing import Dict, Tuple
+
+from ariadne.contrib.tracing.utils import is_introspection_key
+from graphql import (
+    ASTValidationRule,
+    DefinitionNode,
+    FieldNode,
+    FragmentDefinitionNode,
+    FragmentSpreadNode,
+    GraphQLError,
+    InlineFragmentNode,
+    Node,
+    OperationDefinitionNode,
+    ValidationContext,
+)
+from graphql.validation.validate import ValidationAbortedError
+
+
+def depth_limit_validator(max_depth: int):
+    class DepthLimitValidator(ASTValidationRule):
+        def __init__(self, validation_context: ValidationContext):
+            document = validation_context.document
+            definitions = document.definitions
+
+            fragments = get_fragments(definitions)
+            queries = get_queries_and_mutations(definitions)
+            query_depths = {}
+
+            for name in queries:
+                query_depths[name] = determine_depth(
+                    node=queries[name],
+                    fragments=fragments,
+                    depth_so_far=0,
+                    max_depth=max_depth,
+                    context=validation_context,
+                    operation_name=name,
+                )
+            super().__init__(validation_context)
+
+    return DepthLimitValidator
+
+
+def get_fragments(
+    definitions: Tuple[DefinitionNode, ...],
+) -> Dict[str, FragmentDefinitionNode]:
+    fragments = {}
+    for definition in definitions:
+        if isinstance(definition, FragmentDefinitionNode):
+            fragments[definition.name.value] = definition
+    return fragments
+
+
+def get_queries_and_mutations(
+    definitions: Tuple[DefinitionNode, ...],
+) -> Dict[str, OperationDefinitionNode]:
+    operations = {}
+
+    for definition in definitions:
+        if isinstance(definition, OperationDefinitionNode):
+            operation = definition.name.value if definition.name else "anonymous"
+            operations[operation] = definition
+    return operations
+
+
+def determine_depth(
+    node: Node,
+    fragments: Dict[str, FragmentDefinitionNode],
+    depth_so_far: int,
+    max_depth: int,
+    context: ValidationContext,
+    operation_name: str,
+) -> int:
+    if depth_so_far > max_depth:
+        context.report_error(
+            GraphQLError(
+                f"'{operation_name}' exceeds maximum operation depth of {max_depth}.",
+                [node],
+            )
+        )
+        raise ValidationAbortedError
+    if isinstance(node, FieldNode):
+        should_ignore = is_introspection_key(node.name.value)
+
+        if should_ignore or not node.selection_set:
+            return 0
+        return 1 + max(
+            determine_depth(
+                node=selection,
+                fragments=fragments,
+                depth_so_far=depth_so_far + 1,
+                max_depth=max_depth,
+                context=context,
+                operation_name=operation_name,
+            )
+            for selection in node.selection_set.selections
+        )
+    if isinstance(node, FragmentSpreadNode):
+        return determine_depth(
+            node=fragments[node.name.value],
+            fragments=fragments,
+            depth_so_far=depth_so_far,
+            max_depth=max_depth,
+            context=context,
+            operation_name=operation_name,
+        )
+    if isinstance(
+        node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode)
+    ):
+        return max(
+            determine_depth(
+                node=selection,
+                fragments=fragments,
+                depth_so_far=depth_so_far,
+                max_depth=max_depth,
+                context=context,
+                operation_name=operation_name,
+            )
+            for selection in node.selection_set.selections
+        )
+    raise Exception(f"Depth crawler cannot handle: {node.kind}.")  # pragma: no cover
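
As a usage note, the helpers above can be exercised on their own to see how depth is counted: only fields with a selection set add a level, fragments are expanded at their spread site, and introspection fields are skipped. A minimal sketch, with placeholder field and type names that are not taken from the project's schema:

from graphql import parse

from gql.rules import determine_depth, get_fragments, get_queries_and_mutations

# parse() builds an AST without checking names against a schema, so the
# nested fields below are purely illustrative.
document = parse(
    """
    query Timeline {
      timeline_collection {
        ...collectFields
      }
    }
    fragment collectFields on CollectTimeline {
      user { nickname }
    }
    """
)

fragments = get_fragments(document.definitions)
operations = get_queries_and_mutations(document.definitions)

for name, operation in operations.items():
    depth = determine_depth(
        node=operation,
        fragments=fragments,
        depth_so_far=0,
        max_depth=5,
        context=None,  # only consulted once the limit is exceeded, so unused here
        operation_name=name,
    )
    print(name, depth)  # prints: Timeline 2 (timeline_collection and user add a level; nickname is a leaf)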
