Skip to content

Commit

Permalink
Execute queries over GET (#1160)
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp authored Feb 19, 2024
1 parent de79684 commit f4925c6
Show file tree
Hide file tree
Showing 8 changed files with 476 additions and 24 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# CHANGELOG

## 0.23 (UNRELEASED)

- Added `execute_get_queries` setting to the `GraphQL` apps that controls execution of the GraphQL "query" operations made with GET requests. Defaults to `False`.
- Added support for the Apollo Federation versions up to 2.6.


## 0.22 (2024-01-31)

- Deprecated `EnumType.bind_to_default_values` method. It will be removed in a future release.
Expand Down
6 changes: 6 additions & 0 deletions ariadne/asgi/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
query_parser: Optional[QueryParser] = None,
query_validator: Optional[QueryValidator] = None,
validation_rules: Optional[ValidationRules] = None,
execute_get_queries: bool = False,
debug: bool = False,
introspection: bool = True,
explorer: Optional[Explorer] = None,
Expand Down Expand Up @@ -73,6 +74,9 @@ def __init__(
list of extra validation rules server should use to validate the
GraphQL queries. Defaults to `None`.
`execute_get_queries`: a `bool` that controls if `query` operations
sent using the `GET` method should be executed. Defaults to `False`.
`debug`: a `bool` controlling in server should run in debug mode or
not. Controls details included in error data returned to clients.
Defaults to `False`.
Expand Down Expand Up @@ -126,6 +130,7 @@ def __init__(
query_parser,
query_validator,
validation_rules,
execute_get_queries,
debug,
introspection,
explorer,
Expand All @@ -140,6 +145,7 @@ def __init__(
query_parser,
query_validator,
validation_rules,
execute_get_queries,
debug,
introspection,
explorer,
Expand Down
3 changes: 3 additions & 0 deletions ariadne/asgi/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self) -> None:
self.query_parser: Optional[QueryParser] = None
self.query_validator: Optional[QueryValidator] = None
self.validation_rules: Optional[ValidationRules] = None
self.execute_get_queries: bool = False
self.execution_context_class: Optional[Type[ExecutionContext]] = None
self.middleware_manager_class: Optional[Type[MiddlewareManager]] = None

Expand Down Expand Up @@ -79,6 +80,7 @@ def configure(
query_parser: Optional[QueryParser] = None,
query_validator: Optional[QueryValidator] = None,
validation_rules: Optional[ValidationRules] = None,
execute_get_queries: bool = False,
debug: bool = False,
introspection: bool = True,
explorer: Optional[Explorer] = None,
Expand All @@ -94,6 +96,7 @@ def configure(
self.context_value = context_value
self.debug = debug
self.error_formatter = error_formatter
self.execute_get_queries = execute_get_queries
self.execution_context_class = execution_context_class
self.introspection = introspection
self.explorer = explorer
Expand Down
58 changes: 52 additions & 6 deletions ariadne/asgi/handlers/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from inspect import isawaitable
from typing import Any, Optional, Type, cast
from typing import Any, Optional, Type, Union, cast

from graphql import DocumentNode, MiddlewareManager
from starlette.datastructures import UploadFile
Expand Down Expand Up @@ -114,9 +114,12 @@ async def handle_request(self, request: Request) -> Response:
`request`: the `Request` instance from Starlette or FastAPI.
"""
if request.method == "GET" and self.introspection and self.explorer:
# only render explorer when introspection is enabled
return await self.render_explorer(request, self.explorer)
if request.method == "GET":
if self.execute_get_queries and request.query_params.get("query"):
return await self.graphql_http_server(request)
if self.introspection and self.explorer:
# only render explorer when introspection is enabled
return await self.render_explorer(request, self.explorer)

if request.method == "POST":
return await self.graphql_http_server(request)
Expand Down Expand Up @@ -182,14 +185,20 @@ async def extract_data_from_request(self, request: Request):
return await self.extract_data_from_json_request(request)
if content_type == DATA_TYPE_MULTIPART:
return await self.extract_data_from_multipart_request(request)
if (
request.method == "GET"
and self.execute_get_queries
and request.query_params.get("query")
):
return self.extract_data_from_get_request(request)

raise HttpBadRequestError(
"Posted content must be of type {} or {}".format(
DATA_TYPE_JSON, DATA_TYPE_MULTIPART
)
)

async def extract_data_from_json_request(self, request: Request):
async def extract_data_from_json_request(self, request: Request) -> dict:
"""Extracts GraphQL data from JSON request.
Returns a `dict` with GraphQL query data that was not yet validated.
Expand All @@ -203,7 +212,9 @@ async def extract_data_from_json_request(self, request: Request):
except (TypeError, ValueError) as ex:
raise HttpBadRequestError("Request body is not a valid JSON") from ex

async def extract_data_from_multipart_request(self, request: Request):
async def extract_data_from_multipart_request(
self, request: Request
) -> Union[dict, list]:
"""Extracts GraphQL data from `multipart/form-data` request.
Returns an unvalidated `dict` with GraphQL query data.
Expand Down Expand Up @@ -240,6 +251,35 @@ async def extract_data_from_multipart_request(self, request: Request):

return combine_multipart_data(operations, files_map, request_files)

def extract_data_from_get_request(self, request: Request) -> dict:
"""Extracts GraphQL data from GET request's querystring.
Returns a `dict` with GraphQL query data that was not yet validated.
# Required arguments
`request`: the `Request` instance from Starlette or FastAPI.
"""
query = request.query_params["query"].strip()
operation_name = request.query_params.get("operationName", "").strip()
variables = request.query_params.get("variables", "").strip()

clean_variables = None

if variables:
try:
clean_variables = json.loads(variables)
except (TypeError, ValueError) as ex:
raise HttpBadRequestError(
"Variables query arg is not a valid JSON"
) from ex

return {
"query": query,
"operationName": operation_name or None,
"variables": clean_variables,
}

async def execute_graphql_query(
self,
request: Any,
Expand Down Expand Up @@ -275,6 +315,11 @@ async def execute_graphql_query(
if self.schema is None:
raise TypeError("schema is not set, call configure method to initialize it")

if isinstance(request, Request):
require_query = request.method == "GET"
else:
require_query = False

return await graphql(
self.schema,
data,
Expand All @@ -284,6 +329,7 @@ async def execute_graphql_query(
query_validator=self.query_validator,
query_document=query_document,
validation_rules=self.validation_rules,
require_query=require_query,
debug=self.debug,
introspection=self.introspection,
logger=self.logger,
Expand Down
41 changes: 41 additions & 0 deletions ariadne/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GraphQLError,
GraphQLSchema,
MiddlewareManager,
OperationDefinitionNode,
TypeInfo,
execute,
execute_sync,
Expand Down Expand Up @@ -71,6 +72,7 @@ async def graphql(
introspection: bool = True,
logger: Union[None, str, Logger, LoggerAdapter] = None,
validation_rules: Optional[ValidationRules] = None,
require_query: bool = False,
error_formatter: ErrorFormatter = format_error,
middleware: MiddlewareList = None,
middleware_manager_class: Optional[Type[MiddlewareManager]] = None,
Expand Down Expand Up @@ -123,6 +125,9 @@ async def graphql(
`validation_rules`: a `list` of or callable returning list of custom
validation rules to use to validate query before it's executed.
`require_query`: a `bool` controlling if GraphQL operation to execute must be
a query (vs. mutation or subscription).
`error_formatter`: an `ErrorFormatter` callable to use to convert GraphQL
errors encountered during query execution to JSON-serializable format.
Expand Down Expand Up @@ -178,6 +183,9 @@ async def graphql(
extension_manager=extension_manager,
)

if require_query:
validate_operation_is_query(document, operation_name)

if callable(root_value):
try:
root_value = root_value( # type: ignore
Expand Down Expand Up @@ -237,6 +245,7 @@ def graphql_sync(
introspection: bool = True,
logger: Union[None, str, Logger, LoggerAdapter] = None,
validation_rules: Optional[ValidationRules] = None,
require_query: bool = False,
error_formatter: ErrorFormatter = format_error,
middleware: MiddlewareList = None,
middleware_manager_class: Optional[Type[MiddlewareManager]] = None,
Expand Down Expand Up @@ -289,6 +298,9 @@ def graphql_sync(
`validation_rules`: a `list` of or callable returning list of custom
validation rules to use to validate query before it's executed.
`require_query`: a `bool` controlling if GraphQL operation to execute must be
a query (vs. mutation or subscription).
`error_formatter`: an `ErrorFormatter` callable to use to convert GraphQL
errors encountered during query execution to JSON-serializable format.
Expand Down Expand Up @@ -344,6 +356,9 @@ def graphql_sync(
extension_manager=extension_manager,
)

if require_query:
validate_operation_is_query(document, operation_name)

if callable(root_value):
try:
root_value = root_value( # type: ignore
Expand Down Expand Up @@ -639,3 +654,29 @@ def validate_variables(variables) -> None:
def validate_operation_name(operation_name) -> None:
if operation_name is not None and not isinstance(operation_name, str):
raise GraphQLError('"%s" is not a valid operation name.' % operation_name)


def validate_operation_is_query(
document_ast: DocumentNode, operation_name: Optional[str]
):
query_operations: List[Optional[str]] = []
for definition in document_ast.definitions:
if (
isinstance(definition, OperationDefinitionNode)
and definition.operation.name == "QUERY"
):
if definition.name:
query_operations.append(definition.name.value)
else:
query_operations.append(None)

if operation_name:
if operation_name not in query_operations:
raise GraphQLError(
f"Operation '{operation_name}' is not defined or "
"is not of a 'query' type."
)
elif len(query_operations) != 1:
raise GraphQLError(
"'operationName' is required if 'query' defines multiple operations."
)
Loading

0 comments on commit f4925c6

Please sign in to comment.