Skip to content

Commit

Permalink
Merge pull request #9 from TranslatorSRI/async_redis
Browse files Browse the repository at this point in the history
async redis
  • Loading branch information
cbizon authored Dec 20, 2023
2 parents 1e8b158 + c38348a commit 3f878ed
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 35 deletions.
51 changes: 38 additions & 13 deletions src/descender.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@ def __init__(self,rc = None):
When we load from redis, we also pull in the s and o partial patterns which are used to filter at q time.
If you are creating descender with an rc, you also need to call setup on it asynchronously."""
if rc is not None:
db = rc.r[7]
self.pq_to_descendants = jsonpickle.decode(db.get("pq_to_descendants"))
self.type_to_descendants = jsonpickle.decode(db.get("type_to_descendants"))
self.predicate_is_symmetric = jsonpickle.decode(db.get("predicate_symmetries"))
self.s_partial_patterns = jsonpickle.decode(db.get("s_partial_patterns"))
self.o_partial_patterns = jsonpickle.decode(db.get("o_partial_patterns"))
#db = rc.r[7]
#self.pq_to_descendants = jsonpickle.decode(db.get("pq_to_descendants"))
self.pq_to_descendants = None
#self.type_to_descendants = jsonpickle.decode(db.get("type_to_descendants"))
self.type_to_descendants = None
#self.predicate_is_symmetric = jsonpickle.decode(db.get("predicate_symmetries"))
self.predicate_is_symmetric = None
#self.s_partial_patterns = jsonpickle.decode(db.get("s_partial_patterns"))
self.s_partial_patterns = None
#self.o_partial_patterns = jsonpickle.decode(db.get("o_partial_patterns"))
self.o_partial_patterns = None
self.pq_to_descendant_int_ids = None
#Need to hang onto this b/c we are going to lazy load pq_to_descendant_int_ids. Doing it here is a pain from an async perspective
self.rc = rc
Expand All @@ -29,7 +34,19 @@ def __init__(self,rc = None):
self.predicate_is_symmetric = self.create_is_symmetric()
self.deeptypescache = {}

def is_symmetric(self, predicate):
async def get_s_partial_patterns(self):
if self.s_partial_patterns is None:
self.s_partial_patterns = jsonpickle.decode(await self.rc.r[7].get("s_partial_patterns"))
return self.s_partial_patterns

async def get_o_partial_patterns(self):
if self.o_partial_patterns is None:
self.o_partial_patterns = jsonpickle.decode(await self.rc.r[7].get("o_partial_patterns"))
return self.o_partial_patterns

async def is_symmetric(self, predicate):
if self.predicate_is_symmetric is None:
self.predicate_is_symmetric = jsonpickle.decode(await self.rc.r[7].get("predicate_symmetries"))
return self.predicate_is_symmetric[predicate]
def create_is_symmetric(self):
# Create a dictionary from predicate to whether it is symmetric
Expand Down Expand Up @@ -101,16 +118,24 @@ def create_pq_to_descendants(self):
if original_pk in v:
pq_to_descendants[k].update(decs[original_pk])
return pq_to_descendants
def get_type_descendants(self, t):
async def get_type_descendants(self, t):
if self.type_to_descendants is None:
self.type_to_descendants = jsonpickle.decode(await self.rc.r[7].get("type_to_descendants"))
return self.type_to_descendants[t]
def get_pq_descendants(self, pq):
try:
return self.pq_to_descendants[pq]
except:
return [pq]
#async def get_pq_descendants(self, pq):
# try:
# if self.pq_to_descendants is None:
# self.pq_to_descendants = await self.create_pq_to_descendants(self.rc)
# return self.pq_to_descendants[pq]
# except:
# return [pq]
async def create_pq_to_descendant_int_ids(self,rc):
# Create a dictionary from pq to all of its descendant integer ids
# First, pull the integer id for every pq
    # Lazy create pq_to_descendants by pulling it from redis
if self.pq_to_descendants is None:
pkl = await self.rc.r[7].get("pq_to_descendants")
self.pq_to_descendants = jsonpickle.decode(pkl)
pql = list(self.pq_to_descendants.keys())
pq_int_ids = await rc.pipeline_gets(3, pql, True)
# now convert pq_to_descendants into int id values
Expand Down
14 changes: 7 additions & 7 deletions src/query_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ async def gquery(input_curies, pq, output_type, input_is_subject, descender, rc,
for type_int_id in type_int_ids:
for pq_int_id in pq_int_ids:
#Filter to the ones that are actually in the db
if f"{pq_int_id},{type_int_id}" in descender.s_partial_patterns:
if f"{pq_int_id},{type_int_id}" in await descender.get_s_partial_patterns():
for iid in input_int_ids:
query_patterns.append(create_query_pattern(iid, pq_int_id, type_int_id))
iid_list.append(iid)
else:
for type_int_id in type_int_ids:
for pq_int_id in pq_int_ids:
#Filter to the ones that are actually in the db
if f"{type_int_id},-{pq_int_id}" in descender.o_partial_patterns:
if f"{type_int_id},-{pq_int_id}" in await descender.get_o_partial_patterns():
for iid in input_int_ids:
query_patterns.append(create_query_pattern(type_int_id, -pq_int_id, iid) )
iid_list.append(iid)
Expand Down Expand Up @@ -105,12 +105,12 @@ async def gquery(input_curies, pq, output_type, input_is_subject, descender, rc,
async def get_results_for_query_patterns(pipelines, query_patterns):
    """Queue an LRANGE for every query pattern on pipeline 5, then await the batch.

    Returns one list of values per query pattern, in the same order the
    patterns were supplied."""
    for qp in query_patterns:
        pipelines[5].lrange(qp, 0, -1)
    # Single round trip: the awaited execute() flushes all queued commands.
    results = await pipelines[5].execute()
    return results


async def get_type_int_ids(descender, output_type, rc):
    """Map *output_type* to the redis integer ids of it and its descendant types.

    Asks the descender for the type's descendants, then batch-fetches their
    integer ids via pipeline 2 (values converted to int by pipeline_gets).
    Returns the dict-values view of found ids; types missing from redis are
    silently dropped by pipeline_gets."""
    output_types = await descender.get_type_descendants(output_type)
    res = await rc.pipeline_gets(2, output_types, True)
    type_int_ids = res.values()
    return type_int_ids
Expand All @@ -119,9 +119,9 @@ async def get_type_int_ids(descender, output_type, rc):


async def get_strings(input_int_ids, output_node_ids, edge_ids, rc):
    """Fetch serialized node and edge strings for the given integer ids.

    Node strings come from redis db 1 (ids deduplicated via set before the
    mget); edge strings come from db 4 with order preserved. Returns the
    tuple (input_node_strings, output_node_strings, edge_strings)."""
    input_node_strings = await rc.r[1].mget(set(input_int_ids))
    output_node_strings = await rc.r[1].mget(set(output_node_ids))

    edge_strings = await rc.r[4].mget(edge_ids)

    return input_node_strings, output_node_strings, edge_strings
19 changes: 7 additions & 12 deletions src/redis_connector.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
import redis
import redis.asyncio as redis

class RedisConnection:
# RedisConnection is a class that holds a connection to a redis database
# it is a context manager and can be used in a with statement
def __init__(self,host,port,password):
self.r = []
self.r.append(redis.StrictRedis(host=host, port=port, db=0, password=password))
self.r.append(redis.StrictRedis(host=host, port=port, db=1, password=password))
self.r.append(redis.StrictRedis(host=host, port=port, db=2, password=password))
self.r.append(redis.StrictRedis(host=host, port=port, db=3, password=password))
self.r.append(redis.StrictRedis(host=host, port=port, db=4, password=password))
self.r.append(redis.StrictRedis(host=host, port=port, db=5, password=password))
self.r.append(redis.StrictRedis(host=host, port=port, db=6, password=password))
self.r.append(redis.StrictRedis(host=host, port=port, db=7, password=password))
for i in range(8):
self.r.append(redis.StrictRedis(host=host, port=port, db=i, password=password, socket_connect_timeout=600))
self.p = [ rc.pipeline() for rc in self.r ]
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for p in self.p:
p.execute()
for rc in self.r:
rc.close()
rc.aclose()
    def get_pipelines(self):
        """Return the list of pipelines (index i is the pipeline for redis db i)."""
        return self.p
def flush_pipelines(self):
Expand All @@ -32,7 +27,7 @@ async def pipeline_gets(self, pipeline_id, keys, convert_to_int=True):
for key in keys:
pipe = self.p[pipeline_id]
pipe.get(key)
values = self.p[pipeline_id].execute()
values = await self.p[pipeline_id].execute()
if convert_to_int:
s = {k:int(v) for k,v in zip(keys, values) if v is not None}
return s
Expand All @@ -47,7 +42,7 @@ async def get_int_node_ids(self, input_curies):
# Now, extend the input_int_ids with the subclass ids
for iid in input_int_ids:
self.p[6].lrange(iid, 0, -1)
results = self.p[6].execute()
results = await self.p[6].execute()
subclass_int_ids = [int(item) for sublist in results for item in sublist]
input_int_ids.extend(subclass_int_ids)
return input_int_ids
6 changes: 3 additions & 3 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,17 @@ async def query_handler(request: PDResponse):
object_curies = object_node["ids"]
input_nodes, output_nodes, edges = await bquery(subject_curies, pq, object_curies, descender, rc)
# TODO: this is an opportunity for speedup because there is some duplicated work here.
if descender.is_symmetric(q_pred):
if await descender.is_symmetric(q_pred):
output_nodes_r, input_nodes_r, edges_r = await bquery(object_curies, pq, subject_curies, descender, rc)
elif "ids" in subject_node:
subject_curies = subject_node["ids"]
input_nodes, output_nodes, edges = await oquery(subject_curies, pq, object_node["categories"][0], descender, rc)
if descender.is_symmetric(q_pred):
if await descender.is_symmetric(q_pred):
output_nodes_r, input_nodes_r, edges_r = await squery(subject_curies, pq, object_node["categories"][0], descender, rc)
else:
object_curies = object_node["ids"]
input_nodes, output_nodes, edges = await squery(object_curies, pq, subject_node["categories"][0], descender, rc)
if descender.is_symmetric(q_pred):
if await descender.is_symmetric(q_pred):
output_nodes_r, input_nodes_r, edges_r = await oquery(object_curies, pq, subject_node["categories"][0], descender, rc)

# Merge the results, but we need to worry about duplicating nodes. The edges are by construction
Expand Down

0 comments on commit 3f878ed

Please sign in to comment.