From 94ff3de3fd84ade9f06af60effcb8d6faff81f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 29 Feb 2024 18:02:52 +0100 Subject: [PATCH] Add hack for updating final result with data from root value (#1170) --- ariadne/graphql.py | 57 ++++++++++++++++++++++++++++++++++--------- ariadne/types.py | 30 +++++++++++++++++++++++ tests/conftest.py | 1 + tests/test_graphql.py | 36 ++++++++++++++++++++++++++- 4 files changed, 111 insertions(+), 13 deletions(-) diff --git a/ariadne/graphql.py b/ariadne/graphql.py index c17f885a0..474581acc 100644 --- a/ariadne/graphql.py +++ b/ariadne/graphql.py @@ -36,6 +36,7 @@ from .format_error import format_error from .logger import log_error from .types import ( + BaseProxyRootValue, ErrorFormatter, ExtensionList, GraphQLResult, @@ -146,6 +147,8 @@ async def graphql( `**kwargs`: any kwargs not used by `graphql` are passed to `graphql.graphql`. """ + result_update: Optional[BaseProxyRootValue] = None + extension_manager = ExtensionManager(extensions, context_value) with extension_manager.request(): @@ -200,7 +203,11 @@ async def graphql( if isawaitable(root_value): root_value = await root_value - result = execute( + if isinstance(root_value, BaseProxyRootValue): + result_update = root_value + root_value = root_value.root_value + + exec_result = execute( schema, document, root_value=root_value, @@ -214,10 +221,10 @@ async def graphql( **kwargs, ) - if isawaitable(result): - result = await cast(Awaitable[ExecutionResult], result) + if isawaitable(exec_result): + exec_result = await cast(Awaitable[ExecutionResult], exec_result) except GraphQLError as error: - return handle_graphql_errors( + error_result = handle_graphql_errors( [error], logger=logger, error_formatter=error_formatter, @@ -225,14 +232,24 @@ async def graphql( extension_manager=extension_manager, ) - return handle_query_result( - result, + if result_update: + return result_update.update_result(error_result) + + return error_result + + result = handle_query_result( + exec_result, logger=logger, error_formatter=error_formatter, debug=debug, extension_manager=extension_manager, ) + if result_update: + return result_update.update_result(result) + + return result + def graphql_sync( schema: GraphQLSchema, @@ -321,6 +338,8 @@ def graphql_sync( `**kwargs`: any kwargs not used by `graphql_sync` are passed to `graphql.graphql_sync`. """ + result_update: Optional[BaseProxyRootValue] = None + extension_manager = ExtensionManager(extensions, context_value) with extension_manager.request(): @@ -379,7 +398,11 @@ def graphql_sync( "in synchronous query executor." ) - result = execute_sync( + if isinstance(root_value, BaseProxyRootValue): + result_update = root_value + root_value = root_value.root_value + + exec_result = execute_sync( schema, document, root_value=root_value, @@ -393,13 +416,13 @@ def graphql_sync( **kwargs, ) - if isawaitable(result): - ensure_future(cast(Awaitable[ExecutionResult], result)).cancel() + if isawaitable(exec_result): + ensure_future(cast(Awaitable[ExecutionResult], exec_result)).cancel() raise RuntimeError( "GraphQL execution failed to complete synchronously." ) except GraphQLError as error: - return handle_graphql_errors( + error_result = handle_graphql_errors( [error], logger=logger, error_formatter=error_formatter, @@ -407,14 +430,24 @@ def graphql_sync( extension_manager=extension_manager, ) - return handle_query_result( - result, + if result_update: + return result_update.update_result(error_result) + + return error_result + + result = handle_query_result( + exec_result, logger=logger, error_formatter=error_formatter, debug=debug, extension_manager=extension_manager, ) + if result_update: + return result_update.update_result(result) + + return result + async def subscribe( schema: GraphQLSchema, diff --git a/ariadne/types.py b/ariadne/types.py index 3dc21f01a..f77211d59 100644 --- a/ariadne/types.py +++ b/ariadne/types.py @@ -34,6 +34,7 @@ "ErrorFormatter", "ContextValue", "RootValue", + "BaseProxyRootValue", "QueryParser", "QueryValidator", "ValidationRules", @@ -228,6 +229,35 @@ async def get_context_value(request: Request, _): Callable[[Optional[Any], Optional[str], Optional[dict], DocumentNode], Any], ] + +class BaseProxyRootValue: + """A `RootValue` wrapper that includes result JSON update logic. + + Can be returned by the `RootValue` callable. Not used by Ariadne directly + but part of the support for Ariadne GraphQL Proxy. + + # Attributes + + - `root_value: Optional[dict]`: `RootValue` to use during query execution. + """ + + __slots__ = ("root_value",) + + root_value: Optional[dict] + + def __init__(self, root_value: Optional[dict] = None): + self.root_value = root_value + + def update_result(self, result: GraphQLResult) -> GraphQLResult: + """An update function used to create a final `GraphQL` result tuple to + create a JSON response from. + + Default implementation in `BaseProxyRootValue` is a passthrough that + returns `result` value without any changes. + """ + return result + + """Type of `query_parser` option of GraphQL servers. Enables customization of server's GraphQL parsing logic. If not set or `None`, diff --git a/tests/conftest.py b/tests/conftest.py index a48173b72..9333a7824 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ def type_defs(): testContext: String testRoot: String testError: Boolean + context: String } type Mutation { diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 2fd4ea4a1..eb110dfc0 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -3,6 +3,7 @@ from graphql.validation.rules import ValidationRule from ariadne import graphql, graphql_sync, subscribe +from ariadne.types import BaseProxyRootValue class AlwaysInvalid(ValidationRule): @@ -12,6 +13,12 @@ def leave_operation_definition( # pylint: disable=unused-argument self.context.report_error(GraphQLError("Invalid")) +class ProxyRootValue(BaseProxyRootValue): + def update_result(self, result): + success, data = result + return success, {"updated": True, **data} + + def test_graphql_sync_executes_the_query(schema): success, result = graphql_sync(schema, {"query": '{ hello(name: "world") }'}) assert success @@ -51,8 +58,21 @@ def test_graphql_sync_prevents_introspection_query_when_option_is_disabled(schem ) +def test_graphql_sync_executes_the_query_using_result_update_obj(schema): + success, result = graphql_sync( + schema, + {"query": "{ context }"}, + root_value=ProxyRootValue({"context": "Works!"}), + ) + assert success + assert result == { + "data": {"context": "Works!"}, + "updated": True, + } + + @pytest.mark.asyncio -async def test_graphql_execute_the_query(schema): +async def test_graphql_executes_the_query(schema): success, result = await graphql(schema, {"query": '{ hello(name: "world") }'}) assert success assert result["data"] == {"hello": "Hello, world!"} @@ -94,6 +114,20 @@ async def test_graphql_prevents_introspection_query_when_option_is_disabled(sche ) +@pytest.mark.asyncio +async def test_graphql_executes_the_query_using_result_update_obj(schema): + success, result = await graphql( + schema, + {"query": "{ context }"}, + root_value=ProxyRootValue({"context": "Works!"}), + ) + assert success + assert result == { + "data": {"context": "Works!"}, + "updated": True, + } + + @pytest.mark.asyncio async def test_subscription_returns_an_async_iterator(schema): success, result = await subscribe(schema, {"query": "subscription { ping }"})