Skip to content

Commit

Permalink
ExtendedCollection.intersect can accept list and zset keys
Browse files Browse the repository at this point in the history
Before, only sets were allowed (without check, btw)
  • Loading branch information
twidi committed Jan 26, 2018
1 parent 7a4312e commit da4f2d0
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 26 deletions.
2 changes: 1 addition & 1 deletion doc/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ Here is an example:
- a python list
- a python set
- a python tuple
- a string, which must be the key of a Redis_ set (cannot be a list of sorted set for now)
- a string, which must be the key of a Redis_ set, sorted_set or list (long operation if a list)
- a `limpyd` :ref:`SetField`, attached to a model
- a `limpyd` :ref:`ListField`, attached to a model
- a `limpyd` :ref:`SortedSetField`, attached to a model
Expand Down
69 changes: 45 additions & 24 deletions limpyd/contrib/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ def _call_script(self, script_name, keys=[], args=[]):
script['script_object'] = conn.register_script(script['lua'])
return script['script_object'](keys=keys, args=args, client=conn)

def _list_to_set(self, list_field, set_key):
def _list_to_set(self, list_key, set_key):
"""
Store all content of the given ListField in a redis set.
Use scripting if available to avoid retrieving all values locally from
the list before sending them back to the set
"""
if self.cls.database.has_scripting():
self._call_script('list_to_set', keys=[list_field.key, set_key])
self._call_script('list_to_set', keys=[list_key, set_key])
else:
self.cls.get_connection().sadd(set_key, *list_field.lmembers())

conn = self.cls.get_connection()
conn.sadd(set_key, *conn.lrange(list_key, 0, -1))

@property
def _collection(self):
Expand All @@ -101,19 +101,47 @@ def _prepare_sets(self, sets):
As the new "intersect" method can accept different types of "set", we
have to handle them because we must return only keys of redis sets.
"""

if self.stored_key and not self.stored_key_exists():
raise DoesNotExist('This collection is based on a previous one, '
'stored at a key that does not exist anymore.')

conn = self.cls.get_connection()

all_sets = set()
tmp_keys = set()
only_one_set = len(sets) == 1

if self.stored_key and not self.stored_key_exists():
raise DoesNotExist('This collection is based on a previous one, '
'stored at a key that does not exist anymore.')
def add_key(key, key_type=None, is_tmp=False):
if not key_type:
key_type = conn.type(key)
if key_type == 'set':
all_sets.add(key)
elif key_type == 'zset':
all_sets.add(key)
self._has_sortedsets = True
elif key_type == 'list':
if only_one_set:
# we only have this list, use it directly
all_sets.add(key)

This comment has been minimized.

Copy link
@twidi

twidi Aug 28, 2019

Author Collaborator

we may have a problem here: what if something else is added later? We'll have the list in all_sets, and not converted as set. We may have to flag that we only have one list, and use this flag later to convert it if something else is added

else:
# many sets, convert the list to a simple redis set
tmp_key = self._unique_key()
self._list_to_set(key, tmp_key)
add_key(tmp_key, 'set', True)
elif key_type == 'none':
# considered as an empty set
all_sets.add(key)
else:
raise ValueError('Cannot use redis key %s of type %s for filtering' % (
key, key_type
))
if is_tmp:
tmp_keys.add(key)

for set_ in sets:
if isinstance(set_, str):
all_sets.add(set_)
add_key(set_)
elif isinstance(set_, ExtendedFilter):
# We have a RedisModel and we'll use its pk, or a RedisField
# (single value) and we'll use its value
Expand All @@ -126,31 +154,22 @@ def _prepare_sets(self, sets):
else:
raise ValueError(u'Invalide filter value for %s: %s' % (field_name, value))
key = field.index_key(val)
all_sets.add(key)
add_key(key)
elif isinstance(set_, SetField):
# Use the set key. If we need to intersect, we'll use
# sunionstore, and if not, store accepts set
all_sets.add(set_.key)
add_key(set_.key, 'set')
elif isinstance(set_, SortedSetField):
# Use the sorted set key. If we need to intersect, we'll use
# zinterstore, and if not, store accepts zset
all_sets.add(set_.key)
add_key(set_.key, 'zset')
elif isinstance(set_, (ListField, _StoredCollection)):
if only_one_set:
# we only have this list, use it directly
all_sets.add(set_.key)
else:
# many sets, convert the list to a simple redis set
tmp_key = self._unique_key()
self._list_to_set(set_, tmp_key)
tmp_keys.add(tmp_key)
all_sets.add(tmp_key)
add_key(set_.key, 'list')
elif isinstance(set_, tuple) and len(set_):
# if we got a list or set, create a redis set to hold its values
tmp_key = self._unique_key()
conn.sadd(tmp_key, *set_)
tmp_keys.add(tmp_key)
all_sets.add(tmp_key)
add_key(tmp_key, 'set', True)

return all_sets, tmp_keys

Expand All @@ -167,7 +186,8 @@ def intersect(self, *sets):
Each "set" represent a list of pk, the final goal is to return only pks
matching the intersection of all sets.
A "set" can be:
- a string: considered as a redis set's name
- a string: considered as the name of a redis set, sorted set or list
(if a list, values will be stored in a temporary set)
- a list, set or tuple: values will be stored in a temporary set
- a SetField: we will directly use it's content on redis
- a ListField or SortedSetField: values will be stored in a temporary
Expand All @@ -184,7 +204,8 @@ def intersect(self, *sets):
elif not isinstance(set_, (tuple, str, MultiValuesField, _StoredCollection)):
raise ValueError('%s is not a valid type of argument that can '
'be used as a set. Allowed are: string (key '
'of a redis set), limpyd multi-values field ('
'of a redis set, sorted set or list), '
'limpyd multi-values field ('
'SetField, ListField or SortedSetField), or '
'real python set, list or tuple' % set_)
if isinstance(set_, SortedSetField):
Expand Down
44 changes: 43 additions & 1 deletion tests/contrib/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def tearDown(self):
self.connection.sinterstore = IntersectTest.redis_sinterstore
super(IntersectTest, self).tearDown()

def test_intersect_should_accept_string(self):
def test_intersect_should_accept_set_key_as_string(self):
set_key = unique_key(self.connection)
self.connection.sadd(set_key, 1, 2)
collection = set(Group.collection().intersect(set_key))
Expand All @@ -260,6 +260,48 @@ def test_intersect_should_accept_string(self):
collection = set(Group.collection().intersect(set_key))
self.assertEqual(collection, set(['1', '2']))

def test_intersect_should_accept_sortedset_key_as_string(self):
zset_key = unique_key(self.connection)
self.connection.zadd(zset_key, 1.0, 1, 2.0, 2)
collection = set(Group.collection().intersect(zset_key))
self.assertEqual(self.last_interstore_call['command'], 'zinterstore')
self.assertEqual(collection, set(['1', '2']))

zset_key = unique_key(self.connection)
self.connection.zadd(zset_key, 1.0, 1, 2.0, 2, 10.0, 10, 50.0, 50)
collection = set(Group.collection().intersect(zset_key))
self.assertEqual(collection, set(['1', '2']))

def test_intersect_should_accept_list_key_as_string(self):
list_key = unique_key(self.connection)
self.connection.lpush(list_key, 1, 2)
collection = set(Group.collection().intersect(list_key))
self.assertEqual(self.last_interstore_call['command'], 'sinterstore')
self.assertEqual(collection, set(['1', '2']))

list_key = unique_key(self.connection)
self.connection.lpush(list_key, 1, 2, 10, 50)
collection = set(Group.collection().intersect(list_key))
self.assertEqual(collection, set(['1', '2']))

def test_intersect_should_not_accept_string_key_as_string(self):
str_key = unique_key(self.connection)
self.connection.set(str_key, 'foo')
with self.assertRaises(ValueError):
set(Group.collection().intersect(str_key))

def test_intersect_should_not_accept_hkey_key_as_string(self):
hash_key = unique_key(self.connection)
self.connection.hset(hash_key, 'foo', 'bar')
with self.assertRaises(ValueError):
set(Group.collection().intersect(hash_key))

def test_intersect_should_consider_non_existent_key_as_set(self):
no_key = unique_key(self.connection)
collection = set(Group.collection().intersect(no_key))
self.assertEqual(self.last_interstore_call['command'], 'sinterstore')
self.assertEqual(collection, set())

def test_intersect_should_accept_set(self):
collection = set(Group.collection().intersect(set([1, 2])))
self.assertEqual(self.last_interstore_call['command'], 'sinterstore')
Expand Down

0 comments on commit da4f2d0

Please sign in to comment.