Skip to content

Commit

Permalink
Allow class listeners (#1661)
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa authored Oct 28, 2023
1 parent 0ffe41c commit ffc8f11
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
1 change: 1 addition & 0 deletions changes/1661.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow class listeners
11 changes: 8 additions & 3 deletions hikari/impl/event_manager_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ def consume_raw_event(
def subscribe(
self, event_type: typing.Type[typing.Any], callback: event_manager_.CallbackT[typing.Any], *, _nested: int = 0
) -> None:
if not inspect.iscoroutinefunction(callback):
if not (
inspect.iscoroutinefunction(callback) or inspect.iscoroutinefunction(getattr(callback, "__call__", None))
):
raise TypeError("Cannot subscribe a non-coroutine function callback")

# `_nested` is used to show the correct source code snippet if an intent
Expand Down Expand Up @@ -620,8 +622,11 @@ async def _invoke_callback(
try:
await callback(event)
except Exception as ex:
# Skip the first frame in logs, we don't care for it.
trio = type(ex), ex, ex.__traceback__.tb_next if ex.__traceback__ is not None else None
# Skip the first frame in logs if it exists, as it means it wasn't our fault
trio: typing.Union[
typing.Tuple[typing.Type[Exception], Exception, typing.Optional[types.TracebackType]], Exception
]
trio = (type(ex), ex, ex.__traceback__.tb_next) if ex.__traceback__ else ex

if base_events.is_no_recursive_throw_event(event):
_LOGGER.error(
Expand Down
12 changes: 12 additions & 0 deletions tests/hikari/impl/test_event_manager_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,18 @@ async def test_handle_dispatch_invokes_when_consumer_not_enabled(self, event_man
consumer.callback.assert_not_called()
error_handler.assert_not_called()

def test_subscribe_when_class_call(self, event_manager):
class Foo:
async def __call__(self) -> None:
...

foo = Foo()
event_manager._check_event = mock.Mock()

event_manager.subscribe(member_events.MemberCreateEvent, foo)

assert event_manager._listeners[member_events.MemberCreateEvent] == [foo]

def test_subscribe_when_callback_is_not_coroutine(self, event_manager):
def test():
...
Expand Down

0 comments on commit ffc8f11

Please sign in to comment.