Update recommendations.py #18304

Closed · wants to merge 1 commit
99 changes: 58 additions & 41 deletions lib/galaxy/tools/recommendations.py
@@ -6,14 +6,12 @@

 import h5py
 import numpy as np
+import requests
 import yaml

 from galaxy.tools.parameters import populate_state
 from galaxy.tools.parameters.workflow_utils import workflow_building_modes
-from galaxy.util import (
-    DEFAULT_SOCKET_TIMEOUT,
-    requests,
-)
+from galaxy.util import DEFAULT_SOCKET_TIMEOUT
 from galaxy.workflow.modules import module_factory

 log = logging.getLogger(__name__)
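
Note on the import swap above: `galaxy.util` re-exports a `requests` module, and this PR switches to importing the upstream package directly while keeping `DEFAULT_SOCKET_TIMEOUT`. A minimal sketch of how these two names are typically used together; `download_model` is a hypothetical helper for illustration, not code from this file:

```python
import requests

from galaxy.util import DEFAULT_SOCKET_TIMEOUT


def download_model(remote_model_url: str, local_path: str) -> None:
    # Stream the HDF5 model file to disk, bounding each socket operation.
    response = requests.get(remote_model_url, timeout=DEFAULT_SOCKET_TIMEOUT, stream=True)
    response.raise_for_status()
    with open(local_path, "wb") as out:
        for chunk in response.iter_content(chunk_size=8192):
            out.write(chunk)
```
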
@@ -29,12 +27,12 @@ class ToolRecommendations:
     def __init__(self):
         self.tool_recommendation_model_path = None
         self.admin_tool_recommendations_path = None
-        self.deprecated_tools = {}
-        self.admin_recommendations = {}
-        self.model_data_dictionary = {}
-        self.reverse_dictionary = {}
-        self.all_tools = {}
-        self.tool_weights_sorted = {}
+        self.deprecated_tools = dict()
+        self.admin_recommendations = dict()
+        self.model_data_dictionary = dict()
+        self.reverse_dictionary = dict()
+        self.all_tools = dict()
+        self.tool_weights_sorted = dict()
         self.loaded_model = None
         self.compatible_tools = None
         self.standard_connections = None
@@ -62,7 +60,7 @@ def create_transformer_model(self, vocab_size):

         class TransformerBlock(Layer):
             def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
-                super().__init__()
+                super(TransformerBlock, self).__init__()
                 self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate)
                 self.ffn = Sequential([Dense(ff_dim, activation="relu"), Dense(embed_dim)])
                 self.layernorm1 = LayerNormalization(epsilon=1e-6)
@@ -82,7 +80,7 @@ def call(self, inputs, training):

         class TokenAndPositionEmbedding(Layer):
             def __init__(self, maxlen, vocab_size, embed_dim):
-                super().__init__()
+                super(TokenAndPositionEmbedding, self).__init__()
                 self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True)
                 self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim, mask_zero=True)

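
For orientation, the two inner classes above follow the standard Keras pattern for a small transformer classifier. A sketch of how they might compose inside `create_transformer_model`; the hyperparameter values and the sigmoid head are assumptions, not taken from this PR, and `TransformerBlock` / `TokenAndPositionEmbedding` are the classes defined in the hunks above:

```python
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Input
from tensorflow.keras.models import Model

# assumed hyperparameters, for illustration only
maxlen, vocab_size, embed_dim, num_heads, ff_dim = 25, 5000, 128, 4, 128

inputs = Input(shape=(maxlen,))
x = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(inputs)
x = TransformerBlock(embed_dim, num_heads, ff_dim)(x)
x = GlobalAveragePooling1D()(x)
# one independent score per tool in the vocabulary
outputs = Dense(vocab_size, activation="sigmoid")(x)
model = Model(inputs=inputs, outputs=outputs)
```
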
@@ -119,7 +117,7 @@ def get_predictions(self, trans, tool_sequence, remote_model_url):
         """
         Compute tool predictions
         """
-        recommended_tools = {}
+        recommended_tools = dict()
         self.__collect_admin_preferences(trans.app.config.admin_tool_recommendations_path)
         if self.model_ok is None:
             self.__set_model(trans, remote_model_url)
@@ -136,7 +134,7 @@ def __set_model(self, trans, remote_model_url):
             self.loaded_model = self.create_transformer_model(len(self.reverse_dictionary) + 1)
             self.loaded_model.load_weights(self.tool_recommendation_model_path)

-        self.model_data_dictionary = {v: k for k, v in self.reverse_dictionary.items()}
+        self.model_data_dictionary = dict((v, k) for k, v in self.reverse_dictionary.items())
         # set the list of compatible tools
         self.compatible_tools = json.loads(model_file["compatible_tools"][()].decode("utf-8"))
         tool_weights = json.loads(model_file["class_weights"][()].decode("utf-8"))
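
The hunk above reads the model artifacts out of an HDF5 file with `h5py`. A toy sketch of that access pattern, assuming the dataset names shown in the hunk (`compatible_tools`, `class_weights`) and an illustrative file path:

```python
import json

import h5py

# Illustrative path; Galaxy resolves the real one into tool_recommendation_model_path.
with h5py.File("tool_recommendation_model.hdf5", "r") as model_file:
    # scalar string datasets are read with [()] as bytes, then JSON-decoded
    compatible_tools = json.loads(model_file["compatible_tools"][()].decode("utf-8"))
    tool_weights = json.loads(model_file["class_weights"][()].decode("utf-8"))
```
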
@@ -199,8 +197,8 @@ def __get_tool_extensions(self, trans, tool_id):
         module.recover_state(module_state)
         inputs = module.get_all_inputs(connectable_only=True)
         outputs = module.get_all_outputs()
-        input_extensions = []
-        output_extensions = []
+        input_extensions = list()
+        output_extensions = list()
         for i_ext in inputs:
             input_extensions.extend(i_ext["extensions"])
         for o_ext in outputs:
@@ -212,34 +210,27 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_scores, last_tool_name):
         Filter tool predictions based on datatype compatibility and tool connections.
         Add admin preferences to recommendations.
         """
-        last_compatible_tools = []
-        if last_tool_name in self.model_data_dictionary:
-            last_tool_name_id = self.model_data_dictionary[last_tool_name]
-            if last_tool_name_id in self.compatible_tools:
-                last_compatible_tools = [
-                    self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id]
-                ]
-
         prediction_data["is_deprecated"] = False
         # get the list of datatype extensions of the last tool of the tool sequence
         _, last_output_extensions = self.__get_tool_extensions(trans, self.all_tools[last_tool_name][0])
         prediction_data["o_extensions"] = list(set(last_output_extensions))
         t_ids_scores = zip(tool_ids, tool_scores)
         # form the payload of the predicted tools to be shown
         for child, score in t_ids_scores:
-            c_dict = {}
+            c_dict = dict()
             for t_id in self.all_tools:
                 # select the name and tool id if it is installed in Galaxy
                 if (
                     t_id == child
                     and score >= 0.0
-                    and t_id in last_compatible_tools
                     and child not in self.deprecated_tools
                     and child != last_tool_name
                 ):
                     full_tool_id = self.all_tools[t_id][0]
                     pred_input_extensions, _ = self.__get_tool_extensions(trans, full_tool_id)
                     c_dict["name"] = self.all_tools[t_id][1]
                     c_dict["tool_id"] = full_tool_id
                     c_dict["tool_score"] = score
                     c_dict["i_extensions"] = list(set(pred_input_extensions))
                     prediction_data["children"].append(c_dict)
                     break
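
The block removed above (and re-added to `__get_predicted_tools` in the next hunk) hinges on two lookup tables. A toy illustration with invented tool names:

```python
# reverse_dictionary maps model token ids to tool names, model_data_dictionary
# is its inverse, and compatible_tools lists valid successor token ids.
reverse_dictionary = {"1": "bowtie2", "2": "samtools_sort", "3": "featurecounts"}
model_data_dictionary = {v: k for k, v in reverse_dictionary.items()}
compatible_tools = {"1": ["2"], "2": ["3"]}

last_tool_name = "bowtie2"
last_tool_name_id = model_data_dictionary[last_tool_name]  # "1"
last_compatible_tools = [reverse_dictionary[t_id] for t_id in compatible_tools[last_tool_name_id]]
assert last_compatible_tools == ["samtools_sort"]
```
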
@@ -269,25 +260,46 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_scores, last_tool_name):
                     break
         return prediction_data

-    def __get_predicted_tools(self, base_tools, predictions, topk):
+    def __get_predicted_tools(self, pub_tools, predictions, last_tool_name, topk):
         """
         Get predicted tools. If predicted tools are less in number, combine them with published tools
         """
-        t_intersect = list(set(predictions).intersection(set(base_tools)))
-        t_diff = list(set(predictions).difference(set(base_tools)))
+        last_compatible_tools = list()
+        if last_tool_name in self.model_data_dictionary:
+            last_tool_name_id = self.model_data_dictionary[last_tool_name]
+            if last_tool_name_id in self.compatible_tools:
+                last_compatible_tools = [
+                    self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id]
+                ]
+        t_intersect = list(set(predictions).intersection(set(pub_tools)))
+        t_diff = list(set(predictions).difference(set(pub_tools)))
         t_intersect, u_intersect = self.__sort_by_usage(
             t_intersect, self.tool_weights_sorted, self.model_data_dictionary
         )
-        t_diff, u_diff = self.__sort_by_usage(t_diff, self.tool_weights_sorted, self.model_data_dictionary)
-        t_intersect.extend(t_diff)
-        u_intersect.extend(u_diff)
+        t_diff, u_diff = self.__sort_by_usage(
+            t_diff, self.tool_weights_sorted, self.model_data_dictionary
+        )
+        t_intersect_compat = list(set(last_compatible_tools).intersection(set(t_diff)))
+        # filter against rare bad predictions for any tool
+        if len(t_intersect_compat) > 0:
+            t_compat, u_compat = self.__sort_by_usage(
+                t_intersect_compat, self.tool_weights_sorted, self.model_data_dictionary
+            )
+        else:
+            t_compat, u_compat = self.__sort_by_usage(
+                last_compatible_tools, self.tool_weights_sorted, self.model_data_dictionary
+            )
+        t_intersect.extend(t_compat)
+        u_intersect.extend(u_compat)
         t_intersect = t_intersect[:topk]
         u_intersect = u_intersect[:topk]
         return t_intersect, u_intersect

     def __sort_by_usage(self, t_list, class_weights, d_dict):
         """
         Sort predictions by usage/class weights
         """
-        tool_dict = {}
+        tool_dict = dict()
         for tool in t_list:
             t_id = d_dict[tool]
             tool_dict[tool] = class_weights[int(t_id)]
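
`__sort_by_usage` is only partially visible in this hunk; presumably it ranks candidate tools by their class (usage) weight. A self-contained sketch of that behavior with invented names and weights; descending order is an assumption:

```python
def sort_by_usage(t_list, class_weights, d_dict):
    # look up each tool's class weight via its token id, then sort high-to-low
    tool_dict = {tool: class_weights[int(d_dict[tool])] for tool in t_list}
    tool_dict = dict(sorted(tool_dict.items(), key=lambda kv: kv[1], reverse=True))
    return list(tool_dict.keys()), list(tool_dict.values())


tools, usage = sort_by_usage(
    ["cutadapt", "multiqc"], {0: 0.9, 1: 0.4}, {"cutadapt": "1", "multiqc": "0"}
)
assert tools == ["multiqc", "cutadapt"] and usage == [0.9, 0.4]
```
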
@@ -298,18 +310,26 @@ def __separate_predictions(self, base_tools, predictions, last_tool_name, weight_values, topk):
         """
         Get predictions from published and normal workflows
         """
-        last_base_tools = []
+        last_base_tools = list()
+        weight_values = list(self.tool_weights_sorted.values())
+        wt_predictions = predictions * weight_values
         prediction_pos = np.argsort(predictions, axis=-1)
-        topk_prediction_pos = prediction_pos[-topk:]
+        wt_prediction_pos = np.argsort(wt_predictions, axis=-1)
+        topk_prediction_pos = list(prediction_pos[-topk:])
+        wt_topk_prediction_pos = list(wt_prediction_pos[-topk:])
         # get tool ids
+        wt_pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in wt_topk_prediction_pos]
         pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos]
         # exclude same tool as the last tool
+        pred_tool_names.extend(wt_pred_tool_names)
+        pred_tool_names = [item for item in pred_tool_names if item != last_tool_name]
         if last_tool_name in base_tools:
             last_base_tools = base_tools[last_tool_name]
             if type(last_base_tools).__name__ == "str":
                 # get published or compatible tools for the last tool in a sequence of tools
                 last_base_tools = last_base_tools.split(",")
         # get predicted tools
-        sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, topk)
+        sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, last_tool_name, topk)
         return sorted_c_t, sorted_c_v

     def __compute_tool_prediction(self, trans, tool_sequence):
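
The weighted top-k logic added in the hunk above rescales raw model scores by per-tool usage weights and merges both rankings. A small numeric sketch with toy values:

```python
import numpy as np

predictions = np.array([0.10, 0.80, 0.30, 0.60])  # model scores per token id
weight_values = np.array([1.0, 0.1, 2.0, 0.5])    # usage weights per token id
topk = 2

wt_predictions = predictions * weight_values  # [0.1, 0.08, 0.6, 0.3]
topk_pos = list(np.argsort(predictions, axis=-1)[-topk:])        # [3, 1]
wt_topk_pos = list(np.argsort(wt_predictions, axis=-1)[-topk:])  # [3, 2]
# weighting surfaces token 2, which the raw ranking missed
assert topk_pos == [3, 1] and wt_topk_pos == [3, 2]
```
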
@@ -319,10 +339,10 @@ def __compute_tool_prediction(self, trans, tool_sequence):
         Return an empty payload with just the tool sequence if anything goes wrong within the try block
         """
         topk = trans.app.config.topk_recommendations
-        prediction_data = {}
+        prediction_data = dict()
         tool_sequence = tool_sequence.split(",")[::-1]
         prediction_data["name"] = ",".join(tool_sequence)
-        prediction_data["children"] = []
+        prediction_data["children"] = list()
         last_tool_name = tool_sequence[-1]
         # do prediction only if the last is present in the collections of tools
         if last_tool_name in self.model_data_dictionary:
@@ -355,8 +375,5 @@ def __compute_tool_prediction(self, trans, tool_sequence):
             pub_t, pub_v = self.__separate_predictions(
                 self.standard_connections, prediction, last_tool_name, weight_values, topk
             )
-            # remove duplicates if any
-            pub_t = list(dict.fromkeys(pub_t))
-            pub_v = list(dict.fromkeys(pub_v))
             prediction_data = self.__filter_tool_predictions(trans, prediction_data, pub_t, pub_v, last_tool_name)
         return prediction_data
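
For reference, the payload assembled by `__compute_tool_prediction` and `__filter_tool_predictions` has roughly this shape; every value below is invented for illustration:

```python
prediction_data = {
    "name": "bowtie2",        # reversed tool sequence, comma-joined
    "is_deprecated": False,
    "o_extensions": ["bam"],  # output datatypes of the last tool
    "children": [
        {
            "name": "Samtools sort",
            "tool_id": "toolshed.example.org/repos/devteam/samtools_sort/1.0",
            "tool_score": 0.87,
            "i_extensions": ["bam"],
        }
    ],
}
```
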