diff --git a/omero_search_engine/api/v1/resources/query_handler.py b/omero_search_engine/api/v1/resources/query_handler.py index ad57341..47ce2d8 100644 --- a/omero_search_engine/api/v1/resources/query_handler.py +++ b/omero_search_engine/api/v1/resources/query_handler.py @@ -39,7 +39,13 @@ res_or_main_attributes = None -def check_get_names(idr_, resource, attribute, return_exact=False): +def reset_global_values(): + global res_and_main_attributes, res_or_main_attributes + res_and_main_attributes = None + res_or_main_attributes = None + + +def check_get_names(idr_, resource, attribute, data_source, return_exact=False): # check the idr name and return the resource and possible values if idr_: idr_ = idr_.strip() @@ -47,7 +53,13 @@ def check_get_names(idr_, resource, attribute, return_exact=False): all_act_names = [] if pr_names: if not return_exact: - for data_source, pr_names_ in pr_names.items(): + for data_source_, pr_names_ in pr_names.items(): + if ( + data_source + and data_source != "all" + and data_source.lower() != data_source_.lower() + ): + continue act_name = [ name["id"] for name in pr_names_ @@ -57,7 +69,13 @@ def check_get_names(idr_, resource, attribute, return_exact=False): all_act_names = all_act_names + act_name else: # This should be modified to query specific data source specific - for data_source, pr_names_ in pr_names.items(): + for data_source_, pr_names_ in pr_names.items(): + if ( + data_source + and data_source != "all" + and data_source.lower() != data_source_.lower() + ): + continue act_name = [ name["id"] for name in pr_names_ @@ -68,7 +86,7 @@ def check_get_names(idr_, resource, attribute, return_exact=False): class QueryItem(object): - def __init__(self, filter, adjust_res=True): + def __init__(self, filter, data_source, adjust_res=True): """ define query and adjust resource if it is needed, e.g. name is provided @@ -82,6 +100,7 @@ def __init__(self, filter, adjust_res=True): self.name = filter.get("name") self.value = filter.get("value") self.operator = filter.get("operator") + self.data_source = data_source if filter.get("set_query_type") and filter.get("query_type"): self.query_type = filter.get("query_type") else: @@ -100,7 +119,9 @@ def adjust_resource(self): if mapping_names[self.resource].get(self.name): self.name = mapping_names[self.resource].get(self.name) if self.operator == "contains" or self.operator == "not_contains": - ac_value = check_get_names(self.value, self.resource, self.name) + ac_value = check_get_names( + self.value, self.resource, self.name, self.data_source + ) if len(ac_value) == 0: self.value = -1 elif len(ac_value) == 1: @@ -113,7 +134,7 @@ def adjust_resource(self): self.operator = "not_equals" else: ac_value = check_get_names( - self.value, self.resource, self.name, True + self.value, self.resource, self.name, self.data_source, True ) if ac_value and len(ac_value) == 1: self.value = ac_value[0] @@ -229,7 +250,7 @@ def get_image_non_image_query(self): query = {} query["main_attribute"] = main res = self.run_query(query, resource) - new_cond = get_ids(res, resource) + new_cond = get_ids(res, resource, self.data_source) if new_cond: if not main_or_attribute_.get(resource): main_or_attribute_[resource] = new_cond @@ -264,7 +285,7 @@ def get_image_non_image_query(self): query = {} query["or_filters"] = or_query res = self.run_query(query, resource) - new_cond = get_ids(res, resource) + new_cond = get_ids(res, resource, self.data_source) if new_cond: if not main_or_attribute_.get(resource): main_or_attribute_[resource] = new_cond @@ -301,7 +322,7 @@ def get_image_non_image_query(self): query = {} query["main_attribute"] = main res = self.run_query(query, resource) - new_cond = get_ids(res, resource) + new_cond = get_ids(res, resource, self.data_source) if new_cond: if not main_and_attribute.get(resource): main_and_attribute[resource] = new_cond @@ -323,7 +344,7 @@ def get_image_non_image_query(self): query = {} query["and_filters"] = and_query res = self.run_query(query, resource) - new_cond = get_ids(res, resource) + new_cond = get_ids(res, resource, self.data_source) if new_cond: if not main_and_attribute.get(resource): main_and_attribute[resource] = new_cond @@ -351,7 +372,6 @@ def get_image_non_image_query(self): def run_query(self, query_, resource): main_attributes = {} - query = {"and_filters": [], "or_filters": []} if query_.get("and_filters"): @@ -527,7 +547,7 @@ def combine_conds(curnt_cond, new_cond, resource): return returned_cond -def get_ids(results, resource): +def get_ids(results, resource, data_source): ids = [] if results.get("results") and results.get("results").get("results"): for item in results["results"]["results"]: @@ -537,7 +557,7 @@ def get_ids(results, resource): qur_item["value"] = item["id"] qur_item["operator"] = "equals" qur_item["resource"] = resource - qur_item_ = QueryItem(qur_item) + qur_item_ = QueryItem(qur_item, data_source) ids.append(qur_item_) else: qur_item = {} @@ -545,7 +565,7 @@ def get_ids(results, resource): qur_item["value"] = -1 qur_item["operator"] = "equals" qur_item["resource"] = resource - qur_item_ = QueryItem(qur_item) + qur_item_ = QueryItem(qur_item, data_source) ids.append(qur_item_) return ids @@ -657,6 +677,8 @@ def determine_search_results_( ): from omero_search_engine.api.v1.resources.utils import build_error_message + reset_global_values() + if query_.get("query_details"): case_sensitive = query_.get("query_details").get("case_sensitive") else: @@ -678,7 +700,7 @@ def determine_search_results_( if and_filters and len(and_filters) > 0: and_query_group = QueryGroup("and_filters") for filter in and_filters: - q_item = QueryItem(filter) + q_item = QueryItem(filter, data_source) # Check the name value and, if it is a list, # it will create a new or filter for them and move it # Please note it is working for and filter when there is not @@ -720,7 +742,7 @@ def determine_search_results_( or_query_groups.append(or_query_group) if isinstance(filters_, list): for filter in filters_: - q_item = QueryItem(filter) + q_item = QueryItem(filter, data_source) if q_item.query_type == "main_attribute" and ( filter["name"] == "description" ): @@ -739,7 +761,7 @@ def determine_search_results_( new_fil["operator"] = filter["operator"] new_fil["set_query_type"] = True new_fil["query_type"] = q_item.query_type - _q_item = QueryItem(new_fil) + _q_item = QueryItem(new_fil, data_source) or_query_group.add_query(_q_item) else: q_item.name = "id" @@ -776,6 +798,7 @@ def simple_search( data_source, return_containers=False, ): + reset_global_values() if not operator: operator = "equals" if key: