diff --git a/docs/conf.py b/docs/conf.py index bd53efa0..43766c1b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -161,7 +161,9 @@ GraphQLTypeResolver GroupedFieldSet IncrementalDataRecord +InitialResultRecord Middleware +SubsequentDataRecord asyncio.events.AbstractEventLoop graphql.execution.collect_fields.FieldsAndPatches graphql.execution.map_async_iterable.map_async_iterable @@ -169,6 +171,7 @@ graphql.execution.execute.ExperimentalIncrementalExecutionResults graphql.execution.execute.StreamArguments graphql.execution.incremental_publisher.IncrementalPublisher +graphql.execution.incremental_publisher.InitialResultRecord graphql.execution.incremental_publisher.StreamItemsRecord graphql.execution.incremental_publisher.DeferredFragmentRecord graphql.language.lexer.EscapeSequence diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index ae56c9b9..d61909a9 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -86,7 +86,9 @@ IncrementalDataRecord, IncrementalPublisher, IncrementalResult, + InitialResultRecord, StreamItemsRecord, + SubsequentDataRecord, SubsequentIncrementalExecutionResult, ) from .middleware import MiddlewareManager @@ -352,7 +354,6 @@ class ExecutionContext: field_resolver: GraphQLFieldResolver type_resolver: GraphQLTypeResolver subscribe_field_resolver: GraphQLFieldResolver - errors: list[GraphQLError] incremental_publisher: IncrementalPublisher middleware_manager: MiddlewareManager | None @@ -371,7 +372,6 @@ def __init__( field_resolver: GraphQLFieldResolver, type_resolver: GraphQLTypeResolver, subscribe_field_resolver: GraphQLFieldResolver, - errors: list[GraphQLError], incremental_publisher: IncrementalPublisher, middleware_manager: MiddlewareManager | None, is_awaitable: Callable[[Any], bool] | None, @@ -385,7 +385,6 @@ def __init__( self.field_resolver = field_resolver self.type_resolver = type_resolver self.subscribe_field_resolver = subscribe_field_resolver - self.errors = errors self.incremental_publisher = incremental_publisher self.middleware_manager = middleware_manager if is_awaitable: @@ -478,7 +477,6 @@ def build( field_resolver or default_field_resolver, type_resolver or default_type_resolver, subscribe_field_resolver or default_field_resolver, - [], IncrementalPublisher(), middleware_manager, is_awaitable, @@ -514,15 +512,14 @@ def build_per_event_execution_context(self, payload: Any) -> ExecutionContext: self.field_resolver, self.type_resolver, self.subscribe_field_resolver, - [], - # no need to update incrementalPublisher, - # incremental delivery is not supported for subscriptions self.incremental_publisher, self.middleware_manager, self.is_awaitable, ) - def execute_operation(self) -> AwaitableOrValue[dict[str, Any]]: + def execute_operation( + self, initial_result_record: InitialResultRecord + ) -> AwaitableOrValue[dict[str, Any]]: """Execute an operation. Implements the "Executing operations" section of the spec. 
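Net effect of the `execute.py` changes so far: `ExecutionContext` no longer owns a flat `errors` list. Errors belong to the result record that produced them, and `execute_operation()` receives the initial record explicitly from its caller. A minimal standalone sketch of the ownership change (all names here are illustrative stand-ins, not the library's API):

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class InitialRecordSketch:
    """Illustrative stand-in for InitialResultRecord."""

    errors: list[Exception] = field(default_factory=list)
    children: dict[object, None] = field(default_factory=dict)


def execute_operation_sketch(record: InitialRecordSketch) -> dict:
    """Stand-in for execute_operation(initial_result_record)."""
    try:
        raise ValueError("resolver failed")  # simulate a failing field resolver
    except ValueError as error:
        # before this patch: context.errors.append(error)
        # after: the error is attached to the record handed in by the caller
        record.errors.append(error)
    return {"field": None}


record = InitialRecordSketch()
assert execute_operation_sketch(record) == {"field": None}
assert len(record.errors) == 1  # errors travel with the record, not the context
```

One consequence worth noting: `build_per_event_execution_context()` no longer passes a fresh error list per subscription event, because each execution now prepares its own initial result record.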
@@ -551,12 +548,17 @@ def execute_operation(self) -> AwaitableOrValue[dict[str, Any]]: self.execute_fields_serially if operation.operation == OperationType.MUTATION else self.execute_fields - )(root_type, root_value, None, grouped_field_set) # type: ignore + )(root_type, root_value, None, grouped_field_set, initial_result_record) for patch in patches: label, patch_grouped_filed_set = patch self.execute_deferred_fragment( - root_type, root_value, patch_grouped_filed_set, label, None + root_type, + root_value, + patch_grouped_filed_set, + initial_result_record, + label, + None, ) return result @@ -567,6 +569,7 @@ def execute_fields_serially( source_value: Any, path: Path | None, grouped_field_set: GroupedFieldSet, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[dict[str, Any]]: """Execute the given fields serially. @@ -581,7 +584,11 @@ def reducer( response_name, field_group = field_item field_path = Path(path, response_name, parent_type.name) result = self.execute_field( - parent_type, source_value, field_group, field_path + parent_type, + source_value, + field_group, + field_path, + incremental_data_record, ) if result is Undefined: return results @@ -607,7 +614,7 @@ def execute_fields( source_value: Any, path: Path | None, grouped_field_set: GroupedFieldSet, - incremental_data_record: IncrementalDataRecord | None = None, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[dict[str, Any]]: """Execute the given fields concurrently. @@ -662,7 +669,7 @@ def execute_field( source: Any, field_group: FieldGroup, path: Path, - incremental_data_record: IncrementalDataRecord | None = None, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[Any]: """Resolve the field on the given source object. @@ -774,7 +781,7 @@ def handle_field_error( return_type: GraphQLOutputType, field_group: FieldGroup, path: Path, - incremental_data_record: IncrementalDataRecord | None = None, + incremental_data_record: IncrementalDataRecord, ) -> None: """Handle error properly according to the field type.""" error = located_error(raw_error, field_group, path.as_list()) @@ -784,13 +791,9 @@ def handle_field_error( if is_non_null_type(return_type): raise error - errors = ( - incremental_data_record.errors if incremental_data_record else self.errors - ) - # Otherwise, error protection is applied, logging the error and resolving a # null value for this field if one is encountered. - errors.append(error) + self.incremental_publisher.add_field_error(incremental_data_record, error) def complete_value( self, @@ -799,7 +802,7 @@ def complete_value( info: GraphQLResolveInfo, path: Path, result: Any, - incremental_data_record: IncrementalDataRecord | None, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[Any]: """Complete a value. @@ -888,7 +891,7 @@ async def complete_awaitable_value( info: GraphQLResolveInfo, path: Path, result: Any, - incremental_data_record: IncrementalDataRecord | None = None, + incremental_data_record: IncrementalDataRecord, ) -> Any: """Complete an awaitable value.""" try: @@ -955,7 +958,7 @@ async def complete_async_iterator_value( info: GraphQLResolveInfo, path: Path, async_iterator: AsyncIterator[Any], - incremental_data_record: IncrementalDataRecord | None, + incremental_data_record: IncrementalDataRecord, ) -> list[Any]: """Complete an async iterator. 
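With `incremental_data_record` now a required parameter on every completion path, `handle_field_error()` no longer has to choose between `incremental_data_record.errors` and `self.errors`; it always hands the error to the publisher. Assuming a build of graphql-core that includes this patch, the new error path can be exercised in isolation:

```python
from graphql import GraphQLError
from graphql.execution.incremental_publisher import IncrementalPublisher

publisher = IncrementalPublisher()
initial = publisher.prepare_initial_result_record()

# handle_field_error() now always delegates to the publisher; the error
# lands on whichever record (initial, deferred, or streamed) produced it:
error = GraphQLError("field failed")
publisher.add_field_error(initial, error)
assert initial.errors == [error]
```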
@@ -984,8 +987,8 @@ async def complete_async_iterator_value( info, item_type, path, - stream.label, incremental_data_record, + stream.label, ) ), timeout=ASYNC_DELAY, @@ -1039,7 +1042,7 @@ def complete_list_value( info: GraphQLResolveInfo, path: Path, result: AsyncIterable[Any] | Iterable[Any], - incremental_data_record: IncrementalDataRecord | None, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[list[Any]]: """Complete a list value. @@ -1093,8 +1096,8 @@ def complete_list_value( field_group, info, item_type, - stream.label, previous_incremental_data_record, + stream.label, ) continue @@ -1138,7 +1141,7 @@ def complete_list_item_value( field_group: FieldGroup, info: GraphQLResolveInfo, item_path: Path, - incremental_data_record: IncrementalDataRecord | None, + incremental_data_record: IncrementalDataRecord, ) -> bool: """Complete a list item value by adding it to the completed results. @@ -1229,7 +1232,7 @@ def complete_abstract_value( info: GraphQLResolveInfo, path: Path, result: Any, - incremental_data_record: IncrementalDataRecord | None, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[Any]: """Complete an abstract value. @@ -1344,7 +1347,7 @@ def complete_object_value( info: GraphQLResolveInfo, path: Path, result: Any, - incremental_data_record: IncrementalDataRecord | None, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[dict[str, Any]]: """Complete an Object value by executing all sub-selections.""" # If there is an `is_type_of()` predicate function, call it with the current @@ -1379,7 +1382,7 @@ def collect_and_execute_subfields( field_group: FieldGroup, path: Path, result: Any, - incremental_data_record: IncrementalDataRecord | None, + incremental_data_record: IncrementalDataRecord, ) -> AwaitableOrValue[dict[str, Any]]: """Collect sub-fields to execute to complete this value.""" sub_grouped_field_set, sub_patches = self.collect_subfields( @@ -1396,9 +1399,9 @@ def collect_and_execute_subfields( return_type, result, sub_patch_grouped_field_set, + incremental_data_record, label, path, - incremental_data_record, ) return sub_fields @@ -1474,9 +1477,9 @@ def execute_deferred_fragment( parent_type: GraphQLObjectType, source_value: Any, fields: GroupedFieldSet, + parent_context: IncrementalDataRecord, label: str | None = None, path: Path | None = None, - parent_context: IncrementalDataRecord | None = None, ) -> None: """Execute deferred fragment.""" incremental_publisher = self.incremental_publisher @@ -1529,9 +1532,9 @@ def execute_stream_field( field_group: FieldGroup, info: GraphQLResolveInfo, item_type: GraphQLOutputType, + parent_context: IncrementalDataRecord, label: str | None = None, - parent_context: IncrementalDataRecord | None = None, - ) -> IncrementalDataRecord: + ) -> SubsequentDataRecord: """Execute stream field.""" is_awaitable = self.is_awaitable incremental_publisher = self.incremental_publisher @@ -1678,8 +1681,8 @@ async def execute_stream_async_iterator( info: GraphQLResolveInfo, item_type: GraphQLOutputType, path: Path, + parent_context: IncrementalDataRecord, label: str | None = None, - parent_context: IncrementalDataRecord | None = None, ) -> None: """Execute stream iterator.""" incremental_publisher = self.incremental_publisher @@ -1877,21 +1880,24 @@ def execute_impl( # Errors from sub-fields of a NonNull type may propagate to the top level, # at which point we still log the error and null the parent field, which # in this case is the entire response. 
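Note the signature rule applied consistently above: `parent_context` becomes a required positional parameter and moves in front of the optional `label`, so positional call sites (as in `complete_async_iterator_value()` and `complete_list_value()`) must swap the two arguments. A toy sketch of the reorder, pattern only, with none of the real parameters:

```python
def execute_stream_field_old(item_type, label=None, parent_context=None):
    return parent_context, label


def execute_stream_field_new(item_type, parent_context, label=None):
    # parent_context is required and now precedes the optional label
    return parent_context, label


record = object()
# positional call sites swap the last two arguments:
assert execute_stream_field_old("Item", "foo", record) == (record, "foo")
assert execute_stream_field_new("Item", record, "foo") == (record, "foo")
```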
- errors = context.errors incremental_publisher = context.incremental_publisher + initial_result_record = incremental_publisher.prepare_initial_result_record() build_response = context.build_response try: - result = context.execute_operation() + result = context.execute_operation(initial_result_record) if context.is_awaitable(result): # noinspection PyShadowingNames async def await_result() -> Any: try: + errors = incremental_publisher.get_initial_errors( + initial_result_record + ) initial_result = build_response( await result, # type: ignore errors, ) - incremental_publisher.publish_initial() + incremental_publisher.publish_initial(initial_result_record) if incremental_publisher.has_next(): return ExperimentalIncrementalExecutionResults( initial_result=InitialIncrementalExecutionResult( @@ -1902,14 +1908,17 @@ async def await_result() -> Any: subsequent_results=incremental_publisher.subscribe(), ) except GraphQLError as error: - errors.append(error) + incremental_publisher.add_field_error(initial_result_record, error) + errors = incremental_publisher.get_initial_errors( + initial_result_record + ) return build_response(None, errors) return initial_result return await_result() - initial_result = build_response(result, errors) # type: ignore - incremental_publisher.publish_initial() + initial_result = build_response(result, initial_result_record.errors) # type: ignore + incremental_publisher.publish_initial(initial_result_record) if incremental_publisher.has_next(): return ExperimentalIncrementalExecutionResults( initial_result=InitialIncrementalExecutionResult( @@ -1920,7 +1929,8 @@ async def await_result() -> Any: subsequent_results=incremental_publisher.subscribe(), ) except GraphQLError as error: - errors.append(error) + incremental_publisher.add_field_error(initial_result_record, error) + errors = incremental_publisher.get_initial_errors(initial_result_record) return build_response(None, errors) return initial_result diff --git a/src/graphql/execution/incremental_publisher.py b/src/graphql/execution/incremental_publisher.py index fb660e85..bf145da3 100644 --- a/src/graphql/execution/incremental_publisher.py +++ b/src/graphql/execution/incremental_publisher.py @@ -33,6 +33,7 @@ "FormattedIncrementalResult", "FormattedIncrementalStreamResult", "FormattedSubsequentIncrementalExecutionResult", + "InitialResultRecord", "IncrementalDataRecord", "IncrementalDeferResult", "IncrementalPublisher", @@ -340,34 +341,23 @@ class IncrementalPublisher: The internal publishing state is managed as follows: - ``_released``: the set of Incremental Data records that are ready to be sent to the + ``_released``: the set of Subsequent Data records that are ready to be sent to the client, i.e. their parents have completed and they have also completed. - ``_pending``: the set of Incremental Data records that are definitely pending, i.e. + ``_pending``: the set of Subsequent Data records that are definitely pending, i.e. their parents have completed so that they can no longer be filtered. This includes - all Incremental Data records in `released`, as well as Incremental Data records that + all Subsequent Data records in `released`, as well as Subsequent Data records that have not yet completed. - ``_initial_result``: a record containing the state of the initial result, - as follows: - ``is_completed``: indicates whether the initial result has completed. - ``children``: the set of Incremental Data records that can be be published when the - initial result is completed. 
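The reworked `execute_impl()` above hands the whole error lifecycle to the publisher: create the initial record, attach errors to it, read them back, then publish. Assuming the patched module, the same sequence can be driven by hand:

```python
from graphql import GraphQLError
from graphql.execution.incremental_publisher import IncrementalPublisher

publisher = IncrementalPublisher()

# 1. execute_impl() starts by preparing the initial result record:
initial = publisher.prepare_initial_result_record()

# 2. errors raised during execution are attached to that record:
publisher.add_field_error(initial, GraphQLError("oops"))

# 3. just before building the response, the errors are read back:
errors = publisher.get_initial_errors(initial)
assert [error.message for error in errors] == ["oops"]

# 4. publishing the initial result releases its unfiltered children:
publisher.publish_initial(initial)
assert not publisher.has_next()  # no deferred or streamed work was registered
```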
- - Each Incremental Data record also contains similar metadata, i.e. these records also - contain similar ``is_completed`` and ``children`` properties. - Note: Instead of sets we use dicts (with values set to None) which preserve order and thereby achieve more deterministic results. """ - _initial_result: InitialResult - _released: dict[IncrementalDataRecord, None] - _pending: dict[IncrementalDataRecord, None] + _released: dict[SubsequentDataRecord, None] + _pending: dict[SubsequentDataRecord, None] _resolve: Event | None def __init__(self) -> None: - self._initial_result = InitialResult({}, False) self._released = {} self._pending = {} self._resolve = None # lazy initialization @@ -420,33 +410,33 @@ async def subscribe( close_async_iterators.append(close_async_iterator) await gather(*close_async_iterators) + def prepare_initial_result_record(self) -> InitialResultRecord: + """Prepare a new initial result record.""" + return InitialResultRecord(errors=[], children={}) + def prepare_new_deferred_fragment_record( self, label: str | None, path: Path | None, - parent_context: IncrementalDataRecord | None, + parent_context: IncrementalDataRecord, ) -> DeferredFragmentRecord: """Prepare a new deferred fragment record.""" - deferred_fragment_record = DeferredFragmentRecord(label, path, parent_context) + deferred_fragment_record = DeferredFragmentRecord(label, path) - context = parent_context or self._initial_result - context.children[deferred_fragment_record] = None + parent_context.children[deferred_fragment_record] = None return deferred_fragment_record def prepare_new_stream_items_record( self, label: str | None, path: Path | None, - parent_context: IncrementalDataRecord | None, + parent_context: IncrementalDataRecord, async_iterator: AsyncIterator[Any] | None = None, ) -> StreamItemsRecord: """Prepare a new stream items record.""" - stream_items_record = StreamItemsRecord( - label, path, parent_context, async_iterator - ) + stream_items_record = StreamItemsRecord(label, path, async_iterator) - context = parent_context or self._initial_result - context.children[stream_items_record] = None + parent_context.children[stream_items_record] = None return stream_items_record def complete_deferred_fragment_record( @@ -481,29 +471,34 @@ def add_field_error( """Add a field error to the given incremental data record.""" incremental_data_record.errors.append(error) - def publish_initial(self) -> None: + def publish_initial(self, initial_result: InitialResultRecord) -> None: """Publish the initial result.""" - for child in self._initial_result.children: + for child in initial_result.children: + if child.filtered: + continue self._publish(child) + def get_initial_errors( + self, initial_result: InitialResultRecord + ) -> list[GraphQLError]: + """Get the errors from the given initial result.""" + return initial_result.errors + def filter( self, null_path: Path, - erroring_incremental_data_record: IncrementalDataRecord | None, + erroring_incremental_data_record: IncrementalDataRecord, ) -> None: """Filter out the given erroring incremental data record.""" null_path_list = null_path.as_list() - children = (erroring_incremental_data_record or self._initial_result).children + descendants = self._get_descendants(erroring_incremental_data_record.children) - for child in self._get_descendants(children): + for child in descendants: if not self._matches_path(child.path, null_path_list): continue - self._delete(child) - parent = child.parent_context or self._initial_result - with suppress_key_error: - del 
parent.children[child] + child.filtered = True if isinstance(child, StreamItemsRecord): async_iterator = child.async_iterator @@ -522,32 +517,24 @@ def _trigger(self) -> None: resolve.set() self._resolve = Event() - def _introduce(self, item: IncrementalDataRecord) -> None: + def _introduce(self, item: SubsequentDataRecord) -> None: """Introduce a new IncrementalDataRecord.""" self._pending[item] = None - def _release(self, item: IncrementalDataRecord) -> None: + def _release(self, item: SubsequentDataRecord) -> None: """Release the given IncrementalDataRecord.""" if item in self._pending: self._released[item] = None self._trigger() - def _push(self, item: IncrementalDataRecord) -> None: + def _push(self, item: SubsequentDataRecord) -> None: """Push the given IncrementalDataRecord.""" self._released[item] = None self._pending[item] = None self._trigger() - def _delete(self, item: IncrementalDataRecord) -> None: - """Delete the given IncrementalDataRecord.""" - with suppress_key_error: - del self._released[item] - with suppress_key_error: - del self._pending[item] - self._trigger() - def _get_incremental_result( - self, completed_records: Collection[IncrementalDataRecord] + self, completed_records: Collection[SubsequentDataRecord] ) -> SubsequentIncrementalExecutionResult | None: """Get the incremental result with the completed records.""" incremental_results: list[IncrementalResult] = [] @@ -556,6 +543,8 @@ def _get_incremental_result( for incremental_data_record in completed_records: incremental_result: IncrementalResult for child in incremental_data_record.children: + if child.filtered: + continue self._publish(child) if isinstance(incremental_data_record, StreamItemsRecord): items = incremental_data_record.items @@ -591,18 +580,18 @@ def _get_incremental_result( return SubsequentIncrementalExecutionResult(has_next=False) return None - def _publish(self, incremental_data_record: IncrementalDataRecord) -> None: + def _publish(self, subsequent_result_record: SubsequentDataRecord) -> None: """Publish the given incremental data record.""" - if incremental_data_record.is_completed: - self._push(incremental_data_record) + if subsequent_result_record.is_completed: + self._push(subsequent_result_record) else: - self._introduce(incremental_data_record) + self._introduce(subsequent_result_record) def _get_descendants( self, - children: dict[IncrementalDataRecord, None], - descendants: dict[IncrementalDataRecord, None] | None = None, - ) -> dict[IncrementalDataRecord, None]: + children: dict[SubsequentDataRecord, None], + descendants: dict[SubsequentDataRecord, None] | None = None, + ) -> dict[SubsequentDataRecord, None]: """Get the descendants of the given children.""" if descendants is None: descendants = {} @@ -625,6 +614,13 @@ def _add_task(self, awaitable: Awaitable[Any]) -> None: task.add_done_callback(tasks.discard) +class InitialResultRecord(NamedTuple): + """Initial result record""" + + errors: list[GraphQLError] + children: dict[SubsequentDataRecord, None] + + class DeferredFragmentRecord: """A record collecting data marked with the defer directive""" @@ -632,22 +628,16 @@ class DeferredFragmentRecord: label: str | None path: list[str | int] data: dict[str, Any] | None - parent_context: IncrementalDataRecord | None - children: dict[IncrementalDataRecord, None] + children: dict[SubsequentDataRecord, None] is_completed: bool + filtered: bool - def __init__( - self, - label: str | None, - path: Path | None, - parent_context: IncrementalDataRecord | None, - ) ->
None: + def __init__(self, label: str | None, path: Path | None) -> None: self.label = label self.path = path.as_list() if path else [] - self.parent_context = parent_context self.errors = [] self.children = {} - self.is_completed = False + self.is_completed = self.filtered = False self.data = None def __repr__(self) -> str: @@ -655,8 +645,6 @@ def __repr__(self) -> str: args: list[str] = [f"path={self.path!r}"] if self.label: args.append(f"label={self.label!r}") - if self.parent_context: - args.append("parent_context") if self.data is not None: args.append("data") return f"{name}({', '.join(args)})" @@ -669,26 +657,24 @@ class StreamItemsRecord: label: str | None path: list[str | int] items: list[str] | None - parent_context: IncrementalDataRecord | None - children: dict[IncrementalDataRecord, None] + children: dict[SubsequentDataRecord, None] async_iterator: AsyncIterator[Any] | None is_completed_async_iterator: bool is_completed: bool + filtered: bool def __init__( self, label: str | None, path: Path | None, - parent_context: IncrementalDataRecord | None, async_iterator: AsyncIterator[Any] | None = None, ) -> None: self.label = label self.path = path.as_list() if path else [] - self.parent_context = parent_context self.async_iterator = async_iterator self.errors = [] self.children = {} - self.is_completed_async_iterator = self.is_completed = False + self.is_completed_async_iterator = self.is_completed = self.filtered = False self.items = None def __repr__(self) -> str: @@ -696,11 +682,11 @@ def __repr__(self) -> str: args: list[str] = [f"path={self.path!r}"] if self.label: args.append(f"label={self.label!r}") - if self.parent_context: - args.append("parent_context") if self.items is not None: args.append("items") return f"{name}({', '.join(args)})" -IncrementalDataRecord = Union[DeferredFragmentRecord, StreamItemsRecord] +SubsequentDataRecord = Union[DeferredFragmentRecord, StreamItemsRecord] + +IncrementalDataRecord = Union[InitialResultRecord, SubsequentDataRecord] diff --git a/tests/execution/test_defer.py b/tests/execution/test_defer.py index 312a2a0b..41161248 100644 --- a/tests/execution/test_defer.py +++ b/tests/execution/test_defer.py @@ -321,17 +321,13 @@ def can_compare_subsequent_incremental_execution_result(): } def can_print_deferred_fragment_record(): - record = DeferredFragmentRecord(None, None, None) + record = DeferredFragmentRecord(None, None) assert str(record) == "DeferredFragmentRecord(path=[])" - record = DeferredFragmentRecord("foo", Path(None, "bar", "Bar"), record) - assert ( - str(record) == "DeferredFragmentRecord(" - "path=['bar'], label='foo', parent_context)" - ) + record = DeferredFragmentRecord("foo", Path(None, "bar", "Bar")) + assert str(record) == "DeferredFragmentRecord(" "path=['bar'], label='foo')" record.data = {"hello": "world"} assert ( - str(record) == "DeferredFragmentRecord(" - "path=['bar'], label='foo', parent_context, data)" + str(record) == "DeferredFragmentRecord(" "path=['bar'], label='foo', data)" ) @pytest.mark.asyncio diff --git a/tests/execution/test_stream.py b/tests/execution/test_stream.py index 8a1ca605..42188517 100644 --- a/tests/execution/test_stream.py +++ b/tests/execution/test_stream.py @@ -173,18 +173,12 @@ def can_format_and_print_incremental_stream_result(): ) def can_print_stream_record(): - record = StreamItemsRecord(None, None, None, None) + record = StreamItemsRecord(None, None, None) assert str(record) == "StreamItemsRecord(path=[])" - record = StreamItemsRecord("foo", Path(None, "bar", "Bar"), record, None) - 
assert ( - str(record) == "StreamItemsRecord(" - "path=['bar'], label='foo', parent_context)" - ) + record = StreamItemsRecord("foo", Path(None, "bar", "Bar"), None) + assert str(record) == "StreamItemsRecord(" "path=['bar'], label='foo')" record.items = ["hello", "world"] - assert ( - str(record) == "StreamItemsRecord(" - "path=['bar'], label='foo', parent_context, items)" - ) + assert str(record) == "StreamItemsRecord(" "path=['bar'], label='foo', items)" # noinspection PyTypeChecker def can_compare_incremental_stream_result():
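A few closing notes on the publisher rework as a whole.

First, `filter()` now marks descendants with a `filtered` flag instead of deleting them from the publisher's sets (the removed `_delete()`); `publish_initial()` and `_get_incremental_result()` simply skip flagged children at publish time. A standalone sketch of flag-based filtering (all names are illustrative, and the path matching is simplified):

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # identity hashing, so records can serve as dict keys
class RecordSketch:
    """Illustrative stand-in for a SubsequentDataRecord."""

    path: list
    children: dict[RecordSketch, None] = field(default_factory=dict)
    filtered: bool = False


def filter_records(null_path: list, children: dict[RecordSketch, None]) -> None:
    """Flag (rather than delete) descendants at or below the nulled path."""
    for child in children:
        if child.path[: len(null_path)] == null_path:
            child.filtered = True
        filter_records(null_path, child.children)


root = RecordSketch(["hero"])
nested = RecordSketch(["hero", "friends", 0])
root.children[nested] = None

filter_records(["hero", "friends"], {root: None})
assert not root.filtered and nested.filtered  # only the descendant is skipped
```

Second, the aliases at the end of `incremental_publisher.py` make the new split explicit: `SubsequentDataRecord` covers the records that may appear in later payloads, while `IncrementalDataRecord` widens that union with `InitialResultRecord`. Since the initial record is the only `NamedTuple`, `isinstance` checks can narrow the union (assuming the patched module):

```python
from graphql.execution.incremental_publisher import (
    DeferredFragmentRecord,
    IncrementalDataRecord,
    InitialResultRecord,
    StreamItemsRecord,
)


def describe_record(record: IncrementalDataRecord) -> str:
    """Sketch of narrowing the widened union."""
    if isinstance(record, InitialResultRecord):
        return "initial result"
    if isinstance(record, StreamItemsRecord):
        return "stream items"
    return "deferred fragment"  # only DeferredFragmentRecord remains


assert describe_record(InitialResultRecord(errors=[], children={})) == "initial result"
assert describe_record(StreamItemsRecord(None, None, None)) == "stream items"
assert describe_record(DeferredFragmentRecord(None, None)) == "deferred fragment"
```

Finally, the test changes above pin down the slimmer `__repr__` output now that `parent_context` is gone from both record types; the same strings can be reproduced directly:

```python
from graphql.execution.incremental_publisher import (
    DeferredFragmentRecord,
    StreamItemsRecord,
)
from graphql.pyutils import Path

record = DeferredFragmentRecord("foo", Path(None, "bar", "Bar"))
assert str(record) == "DeferredFragmentRecord(path=['bar'], label='foo')"

stream = StreamItemsRecord("foo", Path(None, "bar", "Bar"), None)
stream.items = ["hello", "world"]
assert str(stream) == "StreamItemsRecord(path=['bar'], label='foo', items)"
```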