Skip to content

Commit

Permalink
incremental: subsequent result records should not store parent refere…
Browse files Browse the repository at this point in the history
…nces

Replicates graphql/graphql-js@fae5da5
  • Loading branch information
Cito committed Sep 14, 2024
1 parent 448d045 commit eb9edd5
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 133 deletions.
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,17 @@
GraphQLTypeResolver
GroupedFieldSet
IncrementalDataRecord
InitialResultRecord
Middleware
SubsequentDataRecord
asyncio.events.AbstractEventLoop
graphql.execution.collect_fields.FieldsAndPatches
graphql.execution.map_async_iterable.map_async_iterable
graphql.execution.Middleware
graphql.execution.execute.ExperimentalIncrementalExecutionResults
graphql.execution.execute.StreamArguments
graphql.execution.incremental_publisher.IncrementalPublisher
graphql.execution.incremental_publisher.InitialResultRecord
graphql.execution.incremental_publisher.StreamItemsRecord
graphql.execution.incremental_publisher.DeferredFragmentRecord
graphql.language.lexer.EscapeSequence
Expand Down
92 changes: 51 additions & 41 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@
IncrementalDataRecord,
IncrementalPublisher,
IncrementalResult,
InitialResultRecord,
StreamItemsRecord,
SubsequentDataRecord,
SubsequentIncrementalExecutionResult,
)
from .middleware import MiddlewareManager
Expand Down Expand Up @@ -352,7 +354,6 @@ class ExecutionContext:
field_resolver: GraphQLFieldResolver
type_resolver: GraphQLTypeResolver
subscribe_field_resolver: GraphQLFieldResolver
errors: list[GraphQLError]
incremental_publisher: IncrementalPublisher
middleware_manager: MiddlewareManager | None

Expand All @@ -371,7 +372,6 @@ def __init__(
field_resolver: GraphQLFieldResolver,
type_resolver: GraphQLTypeResolver,
subscribe_field_resolver: GraphQLFieldResolver,
errors: list[GraphQLError],
incremental_publisher: IncrementalPublisher,
middleware_manager: MiddlewareManager | None,
is_awaitable: Callable[[Any], bool] | None,
Expand All @@ -385,7 +385,6 @@ def __init__(
self.field_resolver = field_resolver
self.type_resolver = type_resolver
self.subscribe_field_resolver = subscribe_field_resolver
self.errors = errors
self.incremental_publisher = incremental_publisher
self.middleware_manager = middleware_manager
if is_awaitable:
Expand Down Expand Up @@ -478,7 +477,6 @@ def build(
field_resolver or default_field_resolver,
type_resolver or default_type_resolver,
subscribe_field_resolver or default_field_resolver,
[],
IncrementalPublisher(),
middleware_manager,
is_awaitable,
Expand Down Expand Up @@ -514,15 +512,14 @@ def build_per_event_execution_context(self, payload: Any) -> ExecutionContext:
self.field_resolver,
self.type_resolver,
self.subscribe_field_resolver,
[],
# no need to update incrementalPublisher,
# incremental delivery is not supported for subscriptions
self.incremental_publisher,
self.middleware_manager,
self.is_awaitable,
)

def execute_operation(self) -> AwaitableOrValue[dict[str, Any]]:
def execute_operation(
self, initial_result_record: InitialResultRecord
) -> AwaitableOrValue[dict[str, Any]]:
"""Execute an operation.
Implements the "Executing operations" section of the spec.
Expand Down Expand Up @@ -551,12 +548,17 @@ def execute_operation(self) -> AwaitableOrValue[dict[str, Any]]:
self.execute_fields_serially
if operation.operation == OperationType.MUTATION
else self.execute_fields
)(root_type, root_value, None, grouped_field_set) # type: ignore
)(root_type, root_value, None, grouped_field_set, initial_result_record)

for patch in patches:
label, patch_grouped_filed_set = patch
self.execute_deferred_fragment(
root_type, root_value, patch_grouped_filed_set, label, None
root_type,
root_value,
patch_grouped_filed_set,
initial_result_record,
label,
None,
)

return result
Expand All @@ -567,6 +569,7 @@ def execute_fields_serially(
source_value: Any,
path: Path | None,
grouped_field_set: GroupedFieldSet,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[dict[str, Any]]:
"""Execute the given fields serially.
Expand All @@ -581,7 +584,11 @@ def reducer(
response_name, field_group = field_item
field_path = Path(path, response_name, parent_type.name)
result = self.execute_field(
parent_type, source_value, field_group, field_path
parent_type,
source_value,
field_group,
field_path,
incremental_data_record,
)
if result is Undefined:
return results
Expand All @@ -607,7 +614,7 @@ def execute_fields(
source_value: Any,
path: Path | None,
grouped_field_set: GroupedFieldSet,
incremental_data_record: IncrementalDataRecord | None = None,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[dict[str, Any]]:
"""Execute the given fields concurrently.
Expand Down Expand Up @@ -662,7 +669,7 @@ def execute_field(
source: Any,
field_group: FieldGroup,
path: Path,
incremental_data_record: IncrementalDataRecord | None = None,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[Any]:
"""Resolve the field on the given source object.
Expand Down Expand Up @@ -774,7 +781,7 @@ def handle_field_error(
return_type: GraphQLOutputType,
field_group: FieldGroup,
path: Path,
incremental_data_record: IncrementalDataRecord | None = None,
incremental_data_record: IncrementalDataRecord,
) -> None:
"""Handle error properly according to the field type."""
error = located_error(raw_error, field_group, path.as_list())
Expand All @@ -784,13 +791,9 @@ def handle_field_error(
if is_non_null_type(return_type):
raise error

errors = (
incremental_data_record.errors if incremental_data_record else self.errors
)

# Otherwise, error protection is applied, logging the error and resolving a
# null value for this field if one is encountered.
errors.append(error)
self.incremental_publisher.add_field_error(incremental_data_record, error)

def complete_value(
self,
Expand All @@ -799,7 +802,7 @@ def complete_value(
info: GraphQLResolveInfo,
path: Path,
result: Any,
incremental_data_record: IncrementalDataRecord | None,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[Any]:
"""Complete a value.
Expand Down Expand Up @@ -888,7 +891,7 @@ async def complete_awaitable_value(
info: GraphQLResolveInfo,
path: Path,
result: Any,
incremental_data_record: IncrementalDataRecord | None = None,
incremental_data_record: IncrementalDataRecord,
) -> Any:
"""Complete an awaitable value."""
try:
Expand Down Expand Up @@ -955,7 +958,7 @@ async def complete_async_iterator_value(
info: GraphQLResolveInfo,
path: Path,
async_iterator: AsyncIterator[Any],
incremental_data_record: IncrementalDataRecord | None,
incremental_data_record: IncrementalDataRecord,
) -> list[Any]:
"""Complete an async iterator.
Expand Down Expand Up @@ -984,8 +987,8 @@ async def complete_async_iterator_value(
info,
item_type,
path,
stream.label,
incremental_data_record,
stream.label,
)
),
timeout=ASYNC_DELAY,
Expand Down Expand Up @@ -1039,7 +1042,7 @@ def complete_list_value(
info: GraphQLResolveInfo,
path: Path,
result: AsyncIterable[Any] | Iterable[Any],
incremental_data_record: IncrementalDataRecord | None,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[list[Any]]:
"""Complete a list value.
Expand Down Expand Up @@ -1093,8 +1096,8 @@ def complete_list_value(
field_group,
info,
item_type,
stream.label,
previous_incremental_data_record,
stream.label,
)
continue

Expand Down Expand Up @@ -1138,7 +1141,7 @@ def complete_list_item_value(
field_group: FieldGroup,
info: GraphQLResolveInfo,
item_path: Path,
incremental_data_record: IncrementalDataRecord | None,
incremental_data_record: IncrementalDataRecord,
) -> bool:
"""Complete a list item value by adding it to the completed results.
Expand Down Expand Up @@ -1229,7 +1232,7 @@ def complete_abstract_value(
info: GraphQLResolveInfo,
path: Path,
result: Any,
incremental_data_record: IncrementalDataRecord | None,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[Any]:
"""Complete an abstract value.
Expand Down Expand Up @@ -1344,7 +1347,7 @@ def complete_object_value(
info: GraphQLResolveInfo,
path: Path,
result: Any,
incremental_data_record: IncrementalDataRecord | None,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[dict[str, Any]]:
"""Complete an Object value by executing all sub-selections."""
# If there is an `is_type_of()` predicate function, call it with the current
Expand Down Expand Up @@ -1379,7 +1382,7 @@ def collect_and_execute_subfields(
field_group: FieldGroup,
path: Path,
result: Any,
incremental_data_record: IncrementalDataRecord | None,
incremental_data_record: IncrementalDataRecord,
) -> AwaitableOrValue[dict[str, Any]]:
"""Collect sub-fields to execute to complete this value."""
sub_grouped_field_set, sub_patches = self.collect_subfields(
Expand All @@ -1396,9 +1399,9 @@ def collect_and_execute_subfields(
return_type,
result,
sub_patch_grouped_field_set,
incremental_data_record,
label,
path,
incremental_data_record,
)

return sub_fields
Expand Down Expand Up @@ -1474,9 +1477,9 @@ def execute_deferred_fragment(
parent_type: GraphQLObjectType,
source_value: Any,
fields: GroupedFieldSet,
parent_context: IncrementalDataRecord,
label: str | None = None,
path: Path | None = None,
parent_context: IncrementalDataRecord | None = None,
) -> None:
"""Execute deferred fragment."""
incremental_publisher = self.incremental_publisher
Expand Down Expand Up @@ -1529,9 +1532,9 @@ def execute_stream_field(
field_group: FieldGroup,
info: GraphQLResolveInfo,
item_type: GraphQLOutputType,
parent_context: IncrementalDataRecord,
label: str | None = None,
parent_context: IncrementalDataRecord | None = None,
) -> IncrementalDataRecord:
) -> SubsequentDataRecord:
"""Execute stream field."""
is_awaitable = self.is_awaitable
incremental_publisher = self.incremental_publisher
Expand Down Expand Up @@ -1678,8 +1681,8 @@ async def execute_stream_async_iterator(
info: GraphQLResolveInfo,
item_type: GraphQLOutputType,
path: Path,
parent_context: IncrementalDataRecord,
label: str | None = None,
parent_context: IncrementalDataRecord | None = None,
) -> None:
"""Execute stream iterator."""
incremental_publisher = self.incremental_publisher
Expand Down Expand Up @@ -1877,21 +1880,24 @@ def execute_impl(
# Errors from sub-fields of a NonNull type may propagate to the top level,
# at which point we still log the error and null the parent field, which
# in this case is the entire response.
errors = context.errors
incremental_publisher = context.incremental_publisher
initial_result_record = incremental_publisher.prepare_initial_result_record()
build_response = context.build_response
try:
result = context.execute_operation()
result = context.execute_operation(initial_result_record)

if context.is_awaitable(result):
# noinspection PyShadowingNames
async def await_result() -> Any:
try:
errors = incremental_publisher.get_initial_errors(
initial_result_record
)
initial_result = build_response(
await result, # type: ignore
errors,
)
incremental_publisher.publish_initial()
incremental_publisher.publish_initial(initial_result_record)
if incremental_publisher.has_next():
return ExperimentalIncrementalExecutionResults(
initial_result=InitialIncrementalExecutionResult(
Expand All @@ -1902,14 +1908,17 @@ async def await_result() -> Any:
subsequent_results=incremental_publisher.subscribe(),
)
except GraphQLError as error:
errors.append(error)
incremental_publisher.add_field_error(initial_result_record, error)
errors = incremental_publisher.get_initial_errors(
initial_result_record
)
return build_response(None, errors)
return initial_result

return await_result()

initial_result = build_response(result, errors) # type: ignore
incremental_publisher.publish_initial()
initial_result = build_response(result, initial_result_record.errors) # type: ignore
incremental_publisher.publish_initial(initial_result_record)
if incremental_publisher.has_next():
return ExperimentalIncrementalExecutionResults(
initial_result=InitialIncrementalExecutionResult(
Expand All @@ -1920,7 +1929,8 @@ async def await_result() -> Any:
subsequent_results=incremental_publisher.subscribe(),
)
except GraphQLError as error:
errors.append(error)
incremental_publisher.add_field_error(initial_result_record, error)
errors = incremental_publisher.get_initial_errors(initial_result_record)
return build_response(None, errors)
return initial_result

Expand Down
Loading

0 comments on commit eb9edd5

Please sign in to comment.