Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use defaultdicts for BaseHubConnection handlers instead of lists #50

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions signalrcore/hub/base_hub_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from signalrcore.helpers import Helpers
from .handlers import StreamHandler, InvocationHandler
from ..protocol.messagepack_protocol import MessagePackHubProtocol
from collections import defaultdict


class BaseHubConnection(object):
Expand All @@ -37,8 +38,8 @@ def __init__(
self.token = None # auth
self.state = ConnectionState.disconnected
self.connection_alive = False
self.handlers = []
self.stream_handlers = []
self.handlers = defaultdict(list)
self.stream_handlers = defaultdict(list)
self._thread = None
self._ws = None
self.verify_ssl = verify_ssl
Expand Down Expand Up @@ -114,7 +115,7 @@ def stop(self):

def register_handler(self, event, callback):
self.logger.debug("Handler registered started {0}".format(event))
self.handlers.append((event, callback))
self.handlers[event].append(callback)

def evaluate_handshake(self, message):
self.logger.debug("Evaluating handshake {0}".format(message))
Expand Down Expand Up @@ -180,15 +181,12 @@ def on_message(self, raw_message):
continue

if message.type == MessageType.invocation:
fired_handlers = list(
filter(
lambda h: h[0] == message.target,
self.handlers))
fired_handlers = self.handlers[message.target]
if len(fired_handlers) == 0:
self.logger.warning(
"event '{0}' hasn't fire any handler".format(
message.target))
for _, handler in fired_handlers:
for handler in fired_handlers:
handler(message.arguments)

if message.type == MessageType.close:
Expand All @@ -201,26 +199,17 @@ def on_message(self, raw_message):
self.on_error(message)

# Send callbacks
fired_handlers = list(
filter(
lambda h: h.invocation_id == message.invocation_id,
self.stream_handlers))
fired_handlers = self.stream_handlers[message.invocation_id]

# Stream callbacks
for handler in fired_handlers:
handler.complete_callback(message)

# unregister handler
self.stream_handlers = list(
filter(
lambda h: h.invocation_id != message.invocation_id,
self.stream_handlers))
self.stream_handlers.pop(message.invocation_id)

if message.type == MessageType.stream_item:
fired_handlers = list(
filter(
lambda h: h.invocation_id == message.invocation_id,
self.stream_handlers))
fired_handlers = self.stream_handlers[message.invocation_id]
if len(fired_handlers) == 0:
self.logger.warning(
"id '{0}' hasn't fire any stream handler".format(
Expand All @@ -232,10 +221,7 @@ def on_message(self, raw_message):
pass

if message.type == MessageType.cancel_invocation:
fired_handlers = list(
filter(
lambda h: h.invocation_id == message.invocation_id,
self.stream_handlers))
fired_handlers = self.stream_handlers[message.invocation_id]
if len(fired_handlers) == 0:
self.logger.warning(
"id '{0}' hasn't fire any stream handler".format(
Expand All @@ -245,16 +231,13 @@ def on_message(self, raw_message):
handler.error_callback(message)

# unregister handler
self.stream_handlers = list(
filter(
lambda h: h.invocation_id != message.invocation_id,
self.stream_handlers))
self.stream_handlers.pop(message.invocation_id)

def send(self, message, on_invocation = None):
self.logger.debug("Sending message {0}".format(message))
try:
if on_invocation:
self.stream_handlers.append(InvocationHandler(message.invocation_id, on_invocation))
self.stream_handlers[message.invocation_id].append(InvocationHandler(message.invocation_id, on_invocation))
self._ws.send(self.protocol.encode(message), opcode=0x2 if type(self.protocol) == MessagePackHubProtocol else 0x1)
self.connection_checker.last_message = time.time()
if self.reconnection_handler is not None:
Expand Down Expand Up @@ -300,7 +283,7 @@ def deferred_reconnect(self, sleep_time):
def stream(self, event, event_params):
invocation_id = str(uuid.uuid4())
stream_obj = StreamHandler(event, invocation_id)
self.stream_handlers.append(stream_obj)
self.stream_handlers[invocation_id].append(stream_obj)
self.send(
StreamInvocationMessage(
invocation_id,
Expand Down