From c4daad32a6519db79950243062723b28d2fb1300 Mon Sep 17 00:00:00 2001
From: Trim21
Date: Sun, 22 Sep 2024 13:04:19 +0800
Subject: [PATCH] add limit check

---
 gql/app.py   |  22 ++++++----
 gql/rules.py | 120 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 133 insertions(+), 9 deletions(-)
 create mode 100644 gql/rules.py

diff --git a/gql/app.py b/gql/app.py
index e507359..33db32d 100644
--- a/gql/app.py
+++ b/gql/app.py
@@ -10,13 +10,10 @@
 from chii.compat import phpseralize
 from chii.const import CollectionType
 from chii.db import sa
-from chii.db.tables import (
-    ChiiTimeline,
-    ChiiTimeline_column_cat,
-    ChiiTimeline_column_id,
-)
+from chii.db.tables import ChiiTimeline, ChiiTimeline_column_cat, ChiiTimeline_column_id
 from chii.timeline import TimelineCat
 from gql.model import CollectTimeline
+from gql.rules import depth_limit_validator
 
 # Define types using Schema Definition Language (https://graphql.org/learn/schema/)
 # Wrapping string in gql function provides validation and better error traceback
@@ -24,7 +21,7 @@
     Path(__file__, "..", "schema.graphql").resolve().read_text(encoding="utf8")
 )
 
-CessionMaker = sa.async_session_maker()
+CreateSession = sa.async_session_maker()
 
 # Map resolver functions to Query fields using QueryType
 query = QueryType()
@@ -33,8 +30,8 @@
 # Resolvers are simple python functions
 @query.field("timeline_collection")
 async def timeline_collection(*_: Any) -> list[CollectTimeline]:
-    async with CessionMaker() as session:
-        rows: list[ChiiTimeline] = await session.execute(
+    async with CreateSession() as session:
+        rows: list[tuple[ChiiTimeline,]] = await session.execute(
             select(ChiiTimeline)
             .where(ChiiTimeline_column_cat == TimelineCat.Subject)
             .order_by(ChiiTimeline_column_id.desc())
@@ -77,6 +74,13 @@ async def timeline_collection(*_: Any) -> list[CollectTimeline]:
 app = Starlette(
     debug=True,
     routes=[
-        Mount("/graphql", GraphQL(schema, debug=True)),
+        Mount(
+            "/graphql",
+            GraphQL(
+                schema,
+                debug=True,
+                validation_rules=[depth_limit_validator(max_depth=5)],
+            ),
+        ),
     ],
 )
diff --git a/gql/rules.py b/gql/rules.py
new file mode 100644
index 0000000..edf10db
--- /dev/null
+++ b/gql/rules.py
@@ -0,0 +1,120 @@
+from typing import Dict, Tuple
+
+from ariadne.contrib.tracing.utils import is_introspection_key
+from graphql import (
+    ASTValidationRule,
+    DefinitionNode,
+    FieldNode,
+    FragmentDefinitionNode,
+    FragmentSpreadNode,
+    GraphQLError,
+    InlineFragmentNode,
+    Node,
+    OperationDefinitionNode,
+    ValidationContext,
+)
+from graphql.validation.validate import ValidationAbortedError
+
+
+def depth_limit_validator(max_depth: int):
+    class DepthLimitValidator(ASTValidationRule):
+        def __init__(self, validation_context: ValidationContext):
+            document = validation_context.document
+            definitions = document.definitions
+
+            fragments = get_fragments(definitions)
+            queries = get_queries_and_mutations(definitions)
+            query_depths = {}
+
+            for name in queries:
+                query_depths[name] = determine_depth(
+                    node=queries[name],
+                    fragments=fragments,
+                    depth_so_far=0,
+                    max_depth=max_depth,
+                    context=validation_context,
+                    operation_name=name,
+                )
+            super().__init__(validation_context)
+
+    return DepthLimitValidator
+
+
+def get_fragments(
+    definitions: Tuple[DefinitionNode, ...],
+) -> Dict[str, FragmentDefinitionNode]:
+    fragments = {}
+    for definition in definitions:
+        if isinstance(definition, FragmentDefinitionNode):
+            fragments[definition.name.value] = definition
+    return fragments
+
+
+def get_queries_and_mutations(
+    definitions: Tuple[DefinitionNode, ...],
+) -> Dict[str, OperationDefinitionNode]:
+    operations = {}
+
+    for definition in definitions:
+        if isinstance(definition, OperationDefinitionNode):
+            operation = definition.name.value if definition.name else "anonymous"
+            operations[operation] = definition
+    return operations
+
+
+def determine_depth(
+    node: Node,
+    fragments: Dict[str, FragmentDefinitionNode],
+    depth_so_far: int,
+    max_depth: int,
+    context: ValidationContext,
+    operation_name: str,
+) -> int:
+    if depth_so_far > max_depth:
+        context.report_error(
+            GraphQLError(
+                f"'{operation_name}' exceeds maximum operation depth of {max_depth}.",
+                [node],
+            )
+        )
+        raise ValidationAbortedError
+    if isinstance(node, FieldNode):
+        should_ignore = is_introspection_key(node.name.value)
+
+        if should_ignore or not node.selection_set:
+            return 0
+        return 1 + max(
+            determine_depth(
+                node=selection,
+                fragments=fragments,
+                depth_so_far=depth_so_far + 1,
+                max_depth=max_depth,
+                context=context,
+                operation_name=operation_name,
+            )
+            for selection in node.selection_set.selections
+        )
+    if isinstance(node, FragmentSpreadNode):
+        return determine_depth(
+            node=fragments[node.name.value],
+            fragments=fragments,
+            depth_so_far=depth_so_far,
+            max_depth=max_depth,
+            context=context,
+            operation_name=operation_name,
+        )
+    if isinstance(
+        node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode)
+    ):
+        return max(
+            determine_depth(
+                node=selection,
+                fragments=fragments,
+                depth_so_far=depth_so_far,
+                max_depth=max_depth,
+                context=context,
+                operation_name=operation_name,
+            )
+            for selection in node.selection_set.selections
+        )
+    raise Exception(f"Depth crawler cannot handle: {node.kind}.")  # pragma: no cover