Skip to content

Commit

Permalink
resolve conflict (#10)
Browse files Browse the repository at this point in the history
Co-authored-by: dvora-h <[email protected]>
  • Loading branch information
noah-chae and dvora-h authored Jul 30, 2023
1 parent e85e3f7 commit 1f75b91
Show file tree
Hide file tree
Showing 6 changed files with 559 additions and 37 deletions.
83 changes: 73 additions & 10 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,7 @@ class AbstractRedis:
"QUIT": bool_ok,
"STRALGO": parse_stralgo,
"PUBSUB NUMSUB": parse_pubsub_numsub,
"PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
"RANDOMKEY": lambda r: r and r or None,
"RESET": str_if_bytes,
"SCAN": parse_scan,
Expand Down Expand Up @@ -1365,8 +1366,8 @@ class PubSub:
will be returned and it's safe to start listening again.
"""

PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
HEALTH_CHECK_MESSAGE = "redis-py-health-check"

def __init__(
Expand Down Expand Up @@ -1414,9 +1415,11 @@ def reset(self):
self.connection.clear_connect_callbacks()
self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
self.health_check_response_counter = 0
self.channels = {}
self.pending_unsubscribe_channels = set()
self.shard_channels = {}
self.pending_unsubscribe_shard_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
self.subscribed_event.clear()
Expand All @@ -1431,16 +1434,23 @@ def on_connect(self, connection):
# before passing them to [p]subscribe.
self.pending_unsubscribe_channels.clear()
self.pending_unsubscribe_patterns.clear()
self.pending_unsubscribe_shard_channels.clear()
if self.channels:
channels = {}
for k, v in self.channels.items():
channels[self.encoder.decode(k, force=True)] = v
channels = {
self.encoder.decode(k, force=True): v for k, v in self.channels.items()
}
self.subscribe(**channels)
if self.patterns:
patterns = {}
for k, v in self.patterns.items():
patterns[self.encoder.decode(k, force=True)] = v
patterns = {
self.encoder.decode(k, force=True): v for k, v in self.patterns.items()
}
self.psubscribe(**patterns)
if self.shard_channels:
shard_channels = {
self.encoder.decode(k, force=True): v
for k, v in self.shard_channels.items()
}
self.ssubscribe(**shard_channels)

@property
def subscribed(self):
Expand Down Expand Up @@ -1647,6 +1657,45 @@ def unsubscribe(self, *args):
self.pending_unsubscribe_channels.update(channels)
return self.execute_command("UNSUBSCRIBE", *args)

def ssubscribe(self, *args, target_node=None, **kwargs):
"""
Subscribes the client to the specified shard channels.
Channels supplied as keyword arguments expect a channel name as the key
and a callable as the value. A channel's callable will be invoked automatically
when a message is received on that channel rather than producing a message via
``listen()`` or ``get_sharded_message()``.
"""
if args:
args = list_or_args(args[0], args[1:])
new_s_channels = dict.fromkeys(args)
new_s_channels.update(kwargs)
ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
# update the s_channels dict AFTER we send the command. we don't want to
# subscribe twice to these channels, once for the command and again
# for the reconnection.
new_s_channels = self._normalize_keys(new_s_channels)
self.shard_channels.update(new_s_channels)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
return ret_val

def sunsubscribe(self, *args, target_node=None):
"""
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
all shard_channels
"""
if args:
args = list_or_args(args[0], args[1:])
s_channels = self._normalize_keys(dict.fromkeys(args))
else:
s_channels = self.shard_channels
self.pending_unsubscribe_shard_channels.update(s_channels)
return self.execute_command("SUNSUBSCRIBE", *args)

def listen(self):
"Listen for messages on channels this client has been subscribed to"
while self.subscribed:
Expand Down Expand Up @@ -1681,6 +1730,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
return self.handle_message(response, ignore_subscribe_messages)
return None

get_sharded_message = get_message

def ping(self, message=None):
"""
Ping the Redis server
Expand Down Expand Up @@ -1726,12 +1777,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
if pattern in self.pending_unsubscribe_patterns:
self.pending_unsubscribe_patterns.remove(pattern)
self.patterns.pop(pattern, None)
elif message_type == "sunsubscribe":
s_channel = response[1]
if s_channel in self.pending_unsubscribe_shard_channels:
self.pending_unsubscribe_shard_channels.remove(s_channel)
self.shard_channels.pop(s_channel, None)
else:
channel = response[1]
if channel in self.pending_unsubscribe_channels:
self.pending_unsubscribe_channels.remove(channel)
self.channels.pop(channel, None)
if not self.channels and not self.patterns:
if not self.channels and not self.patterns and not self.shard_channels:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()
Expand All @@ -1740,6 +1796,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
# if there's a message handler, invoke it
if message_type == "pmessage":
handler = self.patterns.get(message["pattern"], None)
elif message_type == "smessage":
handler = self.shard_channels.get(message["channel"], None)
else:
handler = self.channels.get(message["channel"], None)
if handler:
Expand All @@ -1760,6 +1818,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
for pattern, handler in self.patterns.items():
if handler is None:
raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
for s_channel, handler in self.shard_channels.items():
if handler is None:
raise PubSubError(
f"Shard Channel: '{s_channel}' has no handler registered"
)

thread = PubSubWorkerThread(
self, sleep_time, daemon=daemon, exception_handler=exception_handler
Expand Down
105 changes: 101 additions & 4 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands
from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url
from redis.commands.helpers import list_or_args
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.exceptions import (
AskError,
Expand Down Expand Up @@ -219,6 +220,8 @@ class AbstractRedisCluster:
"PUBSUB CHANNELS",
"PUBSUB NUMPAT",
"PUBSUB NUMSUB",
"PUBSUB SHARDCHANNELS",
"PUBSUB SHARDNUMSUB",
"PING",
"INFO",
"SHUTDOWN",
Expand Down Expand Up @@ -343,11 +346,13 @@ class AbstractRedisCluster:
}

RESULT_CALLBACKS = dict_merge(
list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub),
list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub),
list_keys_to_dict(
["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values()))
),
list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result),
list_keys_to_dict(
["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result
),
list_keys_to_dict(
[
"PING",
Expand Down Expand Up @@ -1655,6 +1660,8 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
else redis_cluster.get_redis_connection(self.node).connection_pool
)
self.cluster = redis_cluster
self.node_pubsub_mapping = {}
self._pubsubs_generator = self._pubsubs_generator()
super().__init__(
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
)
Expand Down Expand Up @@ -1708,9 +1715,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port):
f"Node {host}:{port} doesn't exist in the cluster"
)

def execute_command(self, *args, **kwargs):
def execute_command(self, *args):
"""
Execute a publish/subscribe command.
Execute a subscribe/unsubscribe command.
Taken code from redis-py and tweak to make it work within a cluster.
"""
Expand Down Expand Up @@ -1743,13 +1750,103 @@ def execute_command(self, *args, **kwargs):
connection = self.connection
self._execute(connection, connection.send_command, *args)

def _get_node_pubsub(self, node):
try:
return self.node_pubsub_mapping[node.name]
except KeyError:
pubsub = node.redis_connection.pubsub()
self.node_pubsub_mapping[node.name] = pubsub
return pubsub

def _sharded_message_generator(self):
for _ in range(len(self.node_pubsub_mapping)):
pubsub = next(self._pubsubs_generator)
message = pubsub.get_message()
if message is not None:
return message
return None

def _pubsubs_generator(self):
while True:
for pubsub in self.node_pubsub_mapping.values():
yield pubsub

def get_sharded_message(
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
):
if target_node:
message = self.node_pubsub_mapping[target_node.name].get_message(
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
)
else:
message = self._sharded_message_generator()
if message is None:
return None
elif str_if_bytes(message["type"]) == "sunsubscribe":
if message["channel"] in self.pending_unsubscribe_shard_channels:
self.pending_unsubscribe_shard_channels.remove(message["channel"])
self.shard_channels.pop(message["channel"], None)
node = self.cluster.get_node_from_key(message["channel"])
if self.node_pubsub_mapping[node.name].subscribed is False:
self.node_pubsub_mapping.pop(node.name)
if not self.channels and not self.patterns and not self.shard_channels:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()
if self.ignore_subscribe_messages or ignore_subscribe_messages:
return None
return message

def ssubscribe(self, *args, **kwargs):
if args:
args = list_or_args(args[0], args[1:])
s_channels = dict.fromkeys(args)
s_channels.update(kwargs)
for s_channel, handler in s_channels.items():
node = self.cluster.get_node_from_key(s_channel)
pubsub = self._get_node_pubsub(node)
if handler:
pubsub.ssubscribe(**{s_channel: handler})
else:
pubsub.ssubscribe(s_channel)
self.shard_channels.update(pubsub.shard_channels)
self.pending_unsubscribe_shard_channels.difference_update(
self._normalize_keys({s_channel: None})
)
if pubsub.subscribed and not self.subscribed:
self.subscribed_event.set()
self.health_check_response_counter = 0

def sunsubscribe(self, *args):
if args:
args = list_or_args(args[0], args[1:])
else:
args = self.shard_channels

for s_channel in args:
node = self.cluster.get_node_from_key(s_channel)
p = self._get_node_pubsub(node)
p.sunsubscribe(s_channel)
self.pending_unsubscribe_shard_channels.update(
p.pending_unsubscribe_shard_channels
)

def get_redis_connection(self):
"""
Get the Redis connection of the pubsub connected node.
"""
if self.node is not None:
return self.node.redis_connection

def disconnect(self):
"""
Disconnect the pubsub connection.
"""
if self.connection:
self.connection.disconnect()
for pubsub in self.node_pubsub_mapping.values():
pubsub.connection.disconnect()


class ClusterPipeline(RedisCluster):
"""
Expand Down
26 changes: 26 additions & 0 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5090,6 +5090,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT
"""
return self.execute_command("PUBLISH", channel, message, **kwargs)

def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:
"""
Posts a message to the given shard channel.
Returns the number of clients that received the message
For more information see https://redis.io/commands/spublish
"""
return self.execute_command("SPUBLISH", shard_channel, message)

def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
"""
Return a list of channels that have at least one subscriber
Expand All @@ -5098,6 +5107,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
"""
return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs)

def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
"""
Return a list of shard_channels that have at least one subscriber
For more information see https://redis.io/commands/pubsub-shardchannels
"""
return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs)

def pubsub_numpat(self, **kwargs) -> ResponseT:
"""
Returns the number of subscriptions to patterns
Expand All @@ -5115,6 +5132,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT:
"""
return self.execute_command("PUBSUB NUMSUB", *args, **kwargs)

def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT:
"""
Return a list of (shard_channel, number of subscribers) tuples
for each channel given in ``*args``
For more information see https://redis.io/commands/pubsub-shardnumsub
"""
return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs)


AsyncPubSubCommands = PubSubCommands

Expand Down
4 changes: 2 additions & 2 deletions redis/commands/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,13 @@ def _get_pubsub_keys(self, *args):
# the second argument is a part of the command name, e.g.
# ['PUBSUB', 'NUMSUB', 'foo'].
pubsub_type = args[1].upper()
if pubsub_type in ["CHANNELS", "NUMSUB"]:
if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
keys = args[2:]
elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
# format example:
# SUBSCRIBE channel [channel ...]
keys = list(args[1:])
elif command == "PUBLISH":
elif command in ["PUBLISH", "SPUBLISH"]:
# format example:
# PUBLISH channel message
keys = [args[1]]
Expand Down
Loading

0 comments on commit 1f75b91

Please sign in to comment.