Skip to content

Commit

Permalink
Fix issue of accumulative data source in query handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
khaledk2 committed Oct 31, 2024
1 parent 7efd2b5 commit fd67f25
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions omero_search_engine/api/v1/resources/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,27 @@
res_or_main_attributes = None


def check_get_names(idr_, resource, attribute, return_exact=False):
def reset_global_values():
global res_and_main_attributes, res_or_main_attributes
res_and_main_attributes = None
res_or_main_attributes = None


def check_get_names(idr_, resource, attribute, data_source, return_exact=False):
# check the idr name and return the resource and possible values
if idr_:
idr_ = idr_.strip()
pr_names = get_resource_names(resource)
all_act_names = []
if pr_names:
if not return_exact:
for data_source, pr_names_ in pr_names.items():
for data_source_, pr_names_ in pr_names.items():
if (
data_source
and data_source != "all"
and data_source.lower() != data_source_.lower()
):
continue
act_name = [
name["id"]
for name in pr_names_
Expand All @@ -57,7 +69,13 @@ def check_get_names(idr_, resource, attribute, return_exact=False):
all_act_names = all_act_names + act_name
else:
# This should be modified to query specific data source specific
for data_source, pr_names_ in pr_names.items():
for data_source_, pr_names_ in pr_names.items():
if (
data_source
and data_source != "all"
and data_source.lower() != data_source_.lower()
):
continue
act_name = [
name["id"]
for name in pr_names_
Expand All @@ -68,7 +86,7 @@ def check_get_names(idr_, resource, attribute, return_exact=False):


class QueryItem(object):
def __init__(self, filter, adjust_res=True):
def __init__(self, filter, data_source, adjust_res=True):
"""
define query and adjust resource if it is needed,
e.g. name is provided
Expand All @@ -82,6 +100,7 @@ def __init__(self, filter, adjust_res=True):
self.name = filter.get("name")
self.value = filter.get("value")
self.operator = filter.get("operator")
self.data_source = data_source
if filter.get("set_query_type") and filter.get("query_type"):
self.query_type = filter.get("query_type")
else:
Expand All @@ -100,7 +119,9 @@ def adjust_resource(self):
if mapping_names[self.resource].get(self.name):
self.name = mapping_names[self.resource].get(self.name)
if self.operator == "contains" or self.operator == "not_contains":
ac_value = check_get_names(self.value, self.resource, self.name)
ac_value = check_get_names(
self.value, self.resource, self.name, self.data_source
)
if len(ac_value) == 0:
self.value = -1
elif len(ac_value) == 1:
Expand All @@ -113,7 +134,7 @@ def adjust_resource(self):
self.operator = "not_equals"
else:
ac_value = check_get_names(
self.value, self.resource, self.name, True
self.value, self.resource, self.name, self.data_source, True
)
if ac_value and len(ac_value) == 1:
self.value = ac_value[0]
Expand Down Expand Up @@ -229,7 +250,7 @@ def get_image_non_image_query(self):
query = {}
query["main_attribute"] = main
res = self.run_query(query, resource)
new_cond = get_ids(res, resource)
new_cond = get_ids(res, resource, self.data_source)
if new_cond:
if not main_or_attribute_.get(resource):
main_or_attribute_[resource] = new_cond
Expand Down Expand Up @@ -264,7 +285,7 @@ def get_image_non_image_query(self):
query = {}
query["or_filters"] = or_query
res = self.run_query(query, resource)
new_cond = get_ids(res, resource)
new_cond = get_ids(res, resource, self.data_source)
if new_cond:
if not main_or_attribute_.get(resource):
main_or_attribute_[resource] = new_cond
Expand Down Expand Up @@ -301,7 +322,7 @@ def get_image_non_image_query(self):
query = {}
query["main_attribute"] = main
res = self.run_query(query, resource)
new_cond = get_ids(res, resource)
new_cond = get_ids(res, resource, self.data_source)
if new_cond:
if not main_and_attribute.get(resource):
main_and_attribute[resource] = new_cond
Expand All @@ -323,7 +344,7 @@ def get_image_non_image_query(self):
query = {}
query["and_filters"] = and_query
res = self.run_query(query, resource)
new_cond = get_ids(res, resource)
new_cond = get_ids(res, resource, self.data_source)
if new_cond:
if not main_and_attribute.get(resource):
main_and_attribute[resource] = new_cond
Expand Down Expand Up @@ -351,7 +372,6 @@ def get_image_non_image_query(self):

def run_query(self, query_, resource):
main_attributes = {}

query = {"and_filters": [], "or_filters": []}

if query_.get("and_filters"):
Expand Down Expand Up @@ -527,7 +547,7 @@ def combine_conds(curnt_cond, new_cond, resource):
return returned_cond


def get_ids(results, resource):
def get_ids(results, resource, data_source):
ids = []
if results.get("results") and results.get("results").get("results"):
for item in results["results"]["results"]:
Expand All @@ -537,15 +557,15 @@ def get_ids(results, resource):
qur_item["value"] = item["id"]
qur_item["operator"] = "equals"
qur_item["resource"] = resource
qur_item_ = QueryItem(qur_item)
qur_item_ = QueryItem(qur_item, data_source)
ids.append(qur_item_)
else:
qur_item = {}
qur_item["name"] = "{resource}_id".format(resource=resource)
qur_item["value"] = -1
qur_item["operator"] = "equals"
qur_item["resource"] = resource
qur_item_ = QueryItem(qur_item)
qur_item_ = QueryItem(qur_item, data_source)
ids.append(qur_item_)
return ids

Expand Down Expand Up @@ -657,6 +677,8 @@ def determine_search_results_(
):
from omero_search_engine.api.v1.resources.utils import build_error_message

reset_global_values()

if query_.get("query_details"):
case_sensitive = query_.get("query_details").get("case_sensitive")
else:
Expand All @@ -678,7 +700,7 @@ def determine_search_results_(
if and_filters and len(and_filters) > 0:
and_query_group = QueryGroup("and_filters")
for filter in and_filters:
q_item = QueryItem(filter)
q_item = QueryItem(filter, data_source)
# Check the name value and, if it is a list,
# it will create a new or filter for them and move it
# Please note it is working for and filter when there is not
Expand Down Expand Up @@ -720,7 +742,7 @@ def determine_search_results_(
or_query_groups.append(or_query_group)
if isinstance(filters_, list):
for filter in filters_:
q_item = QueryItem(filter)
q_item = QueryItem(filter, data_source)
if q_item.query_type == "main_attribute" and (
filter["name"] == "description"
):
Expand All @@ -739,7 +761,7 @@ def determine_search_results_(
new_fil["operator"] = filter["operator"]
new_fil["set_query_type"] = True
new_fil["query_type"] = q_item.query_type
_q_item = QueryItem(new_fil)
_q_item = QueryItem(new_fil, data_source)
or_query_group.add_query(_q_item)
else:
q_item.name = "id"
Expand Down Expand Up @@ -776,6 +798,7 @@ def simple_search(
data_source,
return_containers=False,
):
reset_global_values()
if not operator:
operator = "equals"
if key:
Expand Down

0 comments on commit fd67f25

Please sign in to comment.