Skip to content

Commit

Permalink
Improved shard connect and shard disconnect to reliably call the even…
Browse files Browse the repository at this point in the history
…t on time (#1744)
  • Loading branch information
davfsa authored Nov 2, 2023
1 parent 48d576d commit 49de2fe
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 12 deletions.
1 change: 1 addition & 0 deletions changes/1744.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure shard connect and disconnect always get sent in pairs and properly waited for
2 changes: 1 addition & 1 deletion hikari/events/shard_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class ShardStateEvent(ShardEvent, abc.ABC):
@attrs_extensions.with_copy
@attrs.define(kw_only=True, weakref_slot=False)
class ShardConnectedEvent(ShardStateEvent):
"""Event fired when a shard connects."""
"""Event fired when a shard successfully connects."""

app: traits.RESTAware = attrs.field(metadata={attrs_extensions.SKIP_DEEP_COPY: True})
# <<inherited docstring from Event>>.
Expand Down
2 changes: 1 addition & 1 deletion hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ async def _request(

await aio.first_completed(request_task, self._close_event.wait())

if not self._close_event.is_set():
if not request_task.cancelled():
return request_task.result()

raise errors.ComponentStateConflictError("The REST client was closed mid-request")
Expand Down
6 changes: 4 additions & 2 deletions hikari/impl/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,6 @@ async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]:
dumps=self._dumps,
url=url,
)
self._event_manager.dispatch(self._event_factory.deserialize_connected_event(self))

# Expect initial HELLO
hello_payload = await self._ws.receive_json()
Expand Down Expand Up @@ -893,6 +892,7 @@ async def _keep_alive(self) -> None:
if not self._handshake_event.is_set():
continue

await self._event_manager.dispatch(self._event_factory.deserialize_connected_event(self))
await aio.first_completed(*lifetime_tasks)

# Since nothing went wrong, we can reset the backoff and try again
Expand Down Expand Up @@ -957,7 +957,9 @@ async def _keep_alive(self) -> None:
else:
await ws.send_close(code=_RESUME_CLOSE_CODE, message=b"shard disconnecting temporarily")

self._event_manager.dispatch(self._event_factory.deserialize_disconnected_event(self))
if self._handshake_event.is_set():
# We dispatched the connected event, so we can dispatch the disconnected one too
await self._event_manager.dispatch(self._event_factory.deserialize_disconnected_event(self))

def _serialize_and_store_presence_payload(
self,
Expand Down
8 changes: 0 additions & 8 deletions tests/hikari/impl/test_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,10 +1014,6 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy
dumps=client._dumps,
url="wss://somewhere.com?somewhere=true&v=400&encoding=json",
)
client._event_factory.deserialize_connected_event.assert_called_once_with(client)
client._event_manager.dispatch.assert_called_once_with(
client._event_factory.deserialize_connected_event.return_value
)

assert create_task.call_count == 2
create_task.assert_has_calls(
Expand Down Expand Up @@ -1103,10 +1099,6 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set
transport_compression=True,
url="wss://notsomewhere.com?somewhere=true&v=400&encoding=json&compress=zlib-stream",
)
client._event_factory.deserialize_connected_event.assert_called_once_with(client)
client._event_manager.dispatch.assert_called_once_with(
client._event_factory.deserialize_connected_event.return_value
)

assert create_task.call_count == 2
create_task.assert_has_calls(
Expand Down

0 comments on commit 49de2fe

Please sign in to comment.