Skip to content

Commit

Permalink
Add hack for updating final result with data from root value (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp authored Feb 29, 2024
1 parent 85f538c commit 94ff3de
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 13 deletions.
57 changes: 45 additions & 12 deletions ariadne/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .format_error import format_error
from .logger import log_error
from .types import (
BaseProxyRootValue,
ErrorFormatter,
ExtensionList,
GraphQLResult,
Expand Down Expand Up @@ -146,6 +147,8 @@ async def graphql(
`**kwargs`: any kwargs not used by `graphql` are passed to
`graphql.graphql`.
"""
result_update: Optional[BaseProxyRootValue] = None

extension_manager = ExtensionManager(extensions, context_value)

with extension_manager.request():
Expand Down Expand Up @@ -200,7 +203,11 @@ async def graphql(
if isawaitable(root_value):
root_value = await root_value

result = execute(
if isinstance(root_value, BaseProxyRootValue):
result_update = root_value
root_value = root_value.root_value

exec_result = execute(
schema,
document,
root_value=root_value,
Expand All @@ -214,25 +221,35 @@ async def graphql(
**kwargs,
)

if isawaitable(result):
result = await cast(Awaitable[ExecutionResult], result)
if isawaitable(exec_result):
exec_result = await cast(Awaitable[ExecutionResult], exec_result)
except GraphQLError as error:
return handle_graphql_errors(
error_result = handle_graphql_errors(
[error],
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

return handle_query_result(
result,
if result_update:
return result_update.update_result(error_result)

return error_result

result = handle_query_result(
exec_result,
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

if result_update:
return result_update.update_result(result)

return result


def graphql_sync(
schema: GraphQLSchema,
Expand Down Expand Up @@ -321,6 +338,8 @@ def graphql_sync(
`**kwargs`: any kwargs not used by `graphql_sync` are passed to
`graphql.graphql_sync`.
"""
result_update: Optional[BaseProxyRootValue] = None

extension_manager = ExtensionManager(extensions, context_value)

with extension_manager.request():
Expand Down Expand Up @@ -379,7 +398,11 @@ def graphql_sync(
"in synchronous query executor."
)

result = execute_sync(
if isinstance(root_value, BaseProxyRootValue):
result_update = root_value
root_value = root_value.root_value

exec_result = execute_sync(
schema,
document,
root_value=root_value,
Expand All @@ -393,28 +416,38 @@ def graphql_sync(
**kwargs,
)

if isawaitable(result):
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
if isawaitable(exec_result):
ensure_future(cast(Awaitable[ExecutionResult], exec_result)).cancel()
raise RuntimeError(
"GraphQL execution failed to complete synchronously."
)
except GraphQLError as error:
return handle_graphql_errors(
error_result = handle_graphql_errors(
[error],
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

return handle_query_result(
result,
if result_update:
return result_update.update_result(error_result)

return error_result

result = handle_query_result(
exec_result,
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

if result_update:
return result_update.update_result(result)

return result


async def subscribe(
schema: GraphQLSchema,
Expand Down
30 changes: 30 additions & 0 deletions ariadne/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"ErrorFormatter",
"ContextValue",
"RootValue",
"BaseProxyRootValue",
"QueryParser",
"QueryValidator",
"ValidationRules",
Expand Down Expand Up @@ -228,6 +229,35 @@ async def get_context_value(request: Request, _):
Callable[[Optional[Any], Optional[str], Optional[dict], DocumentNode], Any],
]


class BaseProxyRootValue:
"""A `RootValue` wrapper that includes result JSON update logic.
Can be returned by the `RootValue` callable. Not used by Ariadne directly
but part of the support for Ariadne GraphQL Proxy.
# Attributes
- `root_value: Optional[dict]`: `RootValue` to use during query execution.
"""

__slots__ = ("root_value",)

root_value: Optional[dict]

def __init__(self, root_value: Optional[dict] = None):
self.root_value = root_value

def update_result(self, result: GraphQLResult) -> GraphQLResult:
"""An update function used to create a final `GraphQL` result tuple to
create a JSON response from.
Default implementation in `BaseProxyRootValue` is a passthrough that
returns `result` value without any changes.
"""
return result


"""Type of `query_parser` option of GraphQL servers.
Enables customization of server's GraphQL parsing logic. If not set or `None`,
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def type_defs():
testContext: String
testRoot: String
testError: Boolean
context: String
}
type Mutation {
Expand Down
36 changes: 35 additions & 1 deletion tests/test_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from graphql.validation.rules import ValidationRule

from ariadne import graphql, graphql_sync, subscribe
from ariadne.types import BaseProxyRootValue


class AlwaysInvalid(ValidationRule):
Expand All @@ -12,6 +13,12 @@ def leave_operation_definition( # pylint: disable=unused-argument
self.context.report_error(GraphQLError("Invalid"))


class ProxyRootValue(BaseProxyRootValue):
def update_result(self, result):
success, data = result
return success, {"updated": True, **data}


def test_graphql_sync_executes_the_query(schema):
success, result = graphql_sync(schema, {"query": '{ hello(name: "world") }'})
assert success
Expand Down Expand Up @@ -51,8 +58,21 @@ def test_graphql_sync_prevents_introspection_query_when_option_is_disabled(schem
)


def test_graphql_sync_executes_the_query_using_result_update_obj(schema):
success, result = graphql_sync(
schema,
{"query": "{ context }"},
root_value=ProxyRootValue({"context": "Works!"}),
)
assert success
assert result == {
"data": {"context": "Works!"},
"updated": True,
}


@pytest.mark.asyncio
async def test_graphql_execute_the_query(schema):
async def test_graphql_executes_the_query(schema):
success, result = await graphql(schema, {"query": '{ hello(name: "world") }'})
assert success
assert result["data"] == {"hello": "Hello, world!"}
Expand Down Expand Up @@ -94,6 +114,20 @@ async def test_graphql_prevents_introspection_query_when_option_is_disabled(sche
)


@pytest.mark.asyncio
async def test_graphql_executes_the_query_using_result_update_obj(schema):
success, result = await graphql(
schema,
{"query": "{ context }"},
root_value=ProxyRootValue({"context": "Works!"}),
)
assert success
assert result == {
"data": {"context": "Works!"},
"updated": True,
}


@pytest.mark.asyncio
async def test_subscription_returns_an_async_iterator(schema):
success, result = await subscribe(schema, {"query": "subscription { ping }"})
Expand Down

0 comments on commit 94ff3de

Please sign in to comment.