Merge pull request #18305 from anuprulez/patch-7
[24.1] Adapt Tool prediction API to Transformer-based deep learning architecture
davelopez authored Jun 4, 2024
2 parents 397d7d3 + c46d09b commit 4f94e2f
Showing 1 changed file with 36 additions and 24 deletions.
lib/galaxy/tools/recommendations.py
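
In short: the change removes the compatible-tools lookup from __filter_tool_predictions (together with its strict per-child compatibility test), recomputes that lookup inside __get_predicted_tools, and uses it there to favor predicted tools that are compatible with the last tool of the sequence, falling back to the full compatible list when the intersection is empty. Below is a minimal, self-contained sketch of that fallback; pick_compatible and the tool names are illustrative stand-ins, not Galaxy's actual API.

    # Hypothetical miniature of the new fallback in __get_predicted_tools.
    def pick_compatible(predicted_extras, last_compatible_tools):
        # prefer predicted tools that are also compatible with the last tool
        compat = set(last_compatible_tools) & set(predicted_extras)
        # an empty intersection means the model suggested nothing compatible,
        # so fall back to recommending the compatible tools themselves
        return sorted(compat) if compat else sorted(last_compatible_tools)

    print(pick_compatible(["cutadapt", "multiqc"], ["multiqc", "fastqc"]))  # ['multiqc']
    print(pick_compatible(["cutadapt"], ["multiqc", "fastqc"]))             # ['fastqc', 'multiqc']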
@@ -212,14 +212,6 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score
         Filter tool predictions based on datatype compatibility and tool connections.
         Add admin preferences to recommendations.
         """
-        last_compatible_tools = []
-        if last_tool_name in self.model_data_dictionary:
-            last_tool_name_id = self.model_data_dictionary[last_tool_name]
-            if last_tool_name_id in self.compatible_tools:
-                last_compatible_tools = [
-                    self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id]
-                ]
-
         prediction_data["is_deprecated"] = False
         # get the list of datatype extensions of the last tool of the tool sequence
         _, last_output_extensions = self.__get_tool_extensions(trans, self.all_tools[last_tool_name][0])
@@ -230,16 +222,12 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score
             c_dict = {}
             for t_id in self.all_tools:
                 # select the name and tool id if it is installed in Galaxy
-                if (
-                    t_id == child
-                    and score >= 0.0
-                    and t_id in last_compatible_tools
-                    and child not in self.deprecated_tools
-                ):
+                if t_id == child and score >= 0.0 and child not in self.deprecated_tools and child != last_tool_name:
                     full_tool_id = self.all_tools[t_id][0]
                     pred_input_extensions, _ = self.__get_tool_extensions(trans, full_tool_id)
                     c_dict["name"] = self.all_tools[t_id][1]
                     c_dict["tool_id"] = full_tool_id
                     c_dict["tool_score"] = score
+                    c_dict["i_extensions"] = list(set(pred_input_extensions))
                     prediction_data["children"].append(c_dict)
                     break
@@ -269,18 +257,37 @@
                     break
         return prediction_data

-    def __get_predicted_tools(self, base_tools, predictions, topk):
+    def __get_predicted_tools(self, pub_tools, predictions, last_tool_name, topk):
         """
         Get predicted tools. If predicted tools are less in number, combine them with published tools
         """
-        t_intersect = list(set(predictions).intersection(set(base_tools)))
-        t_diff = list(set(predictions).difference(set(base_tools)))
+        last_compatible_tools = []
+        if last_tool_name in self.model_data_dictionary:
+            last_tool_name_id = self.model_data_dictionary[last_tool_name]
+            if last_tool_name_id in self.compatible_tools:
+                last_compatible_tools = [
+                    self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id]
+                ]
+        t_intersect = list(set(predictions).intersection(set(pub_tools)))
+        t_diff = list(set(predictions).difference(set(pub_tools)))
         t_intersect, u_intersect = self.__sort_by_usage(
             t_intersect, self.tool_weights_sorted, self.model_data_dictionary
         )
         t_diff, u_diff = self.__sort_by_usage(t_diff, self.tool_weights_sorted, self.model_data_dictionary)
         t_intersect.extend(t_diff)
         u_intersect.extend(u_diff)
+        t_intersect_compat = list(set(last_compatible_tools).intersection(set(t_diff)))
+        # filter against rare bad predictions for any tool
+        if len(t_intersect_compat) > 0:
+            t_compat, u_compat = self.__sort_by_usage(
+                t_intersect_compat, self.tool_weights_sorted, self.model_data_dictionary
+            )
+        else:
+            t_compat, u_compat = self.__sort_by_usage(
+                last_compatible_tools, self.tool_weights_sorted, self.model_data_dictionary
+            )
+        t_intersect.extend(t_compat)
+        u_intersect.extend(u_compat)
         t_intersect = t_intersect[:topk]
         u_intersect = u_intersect[:topk]
         return t_intersect, u_intersect

def __sort_by_usage(self, t_list, class_weights, d_dict):
@@ -299,17 +306,25 @@ def __separate_predictions(self, base_tools, predictions, last_tool_name, weight
         Get predictions from published and normal workflows
         """
         last_base_tools = []
+        weight_values = list(self.tool_weights_sorted.values())
+        wt_predictions = predictions * weight_values
         prediction_pos = np.argsort(predictions, axis=-1)
-        topk_prediction_pos = prediction_pos[-topk:]
+        wt_prediction_pos = np.argsort(wt_predictions, axis=-1)
+        topk_prediction_pos = list(prediction_pos[-topk:])
+        wt_topk_prediction_pos = list(wt_prediction_pos[-topk:])
         # get tool ids
+        wt_pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in wt_topk_prediction_pos]
         pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos]
+        # exclude same tool as the last tool
+        pred_tool_names.extend(wt_pred_tool_names)
+        pred_tool_names = [item for item in pred_tool_names if item != last_tool_name]
         if last_tool_name in base_tools:
             last_base_tools = base_tools[last_tool_name]
             if type(last_base_tools).__name__ == "str":
                 # get published or compatible tools for the last tool in a sequence of tools
                 last_base_tools = last_base_tools.split(",")
         # get predicted tools
-        sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, topk)
+        sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, last_tool_name, topk)
         return sorted_c_t, sorted_c_v

def __compute_tool_prediction(self, trans, tool_sequence):
@@ -355,8 +370,5 @@ def __compute_tool_prediction(self, trans, tool_sequence):
             pub_t, pub_v = self.__separate_predictions(
                 self.standard_connections, prediction, last_tool_name, weight_values, topk
             )
-            # remove duplicates if any
-            pub_t = list(dict.fromkeys(pub_t))
-            pub_v = list(dict.fromkeys(pub_v))
             prediction_data = self.__filter_tool_predictions(trans, prediction_data, pub_t, pub_v, last_tool_name)
         return prediction_data

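For context on the reworked __separate_predictions above: it now ranks candidate tools twice, once on raw model scores and once on usage-weighted scores, merges both top-k lists, and drops the last tool itself. A runnable toy version of that selection follows; the scores, weights, and reverse_dictionary entries are made up, whereas the real ones come from the trained model's data.

    import numpy as np

    # Toy stand-ins for the model artifacts used above; all values are made up.
    reverse_dictionary = {"0": "bwa", "1": "samtools_sort", "2": "multiqc", "3": "fastqc"}
    predictions = np.array([0.10, 0.80, 0.35, 0.60])  # raw model scores per tool index
    weight_values = np.array([0.5, 1.0, 3.0, 1.5])    # usage weights per tool index
    topk = 2

    wt_predictions = predictions * weight_values      # usage-weighted scores
    # np.argsort sorts ascending, so the last topk positions hold the best scores
    topk_prediction_pos = list(np.argsort(predictions, axis=-1)[-topk:])
    wt_topk_prediction_pos = list(np.argsort(wt_predictions, axis=-1)[-topk:])

    # map positions back to tool names and merge both rankings
    pred_tool_names = [reverse_dictionary[str(p)] for p in topk_prediction_pos]
    pred_tool_names.extend(reverse_dictionary[str(p)] for p in wt_topk_prediction_pos)

    # exclude the last tool of the sequence, as the new code does
    last_tool_name = "fastqc"
    print([t for t in pred_tool_names if t != last_tool_name])
    # -> ['samtools_sort', 'multiqc']

Note that the merged list may still contain duplicates; the commit also removes the later dict.fromkeys dedup step in __compute_tool_prediction.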