diff --git a/lib/galaxy/tools/recommendations.py b/lib/galaxy/tools/recommendations.py
index 00d8490a9737..18e8933163c0 100644
--- a/lib/galaxy/tools/recommendations.py
+++ b/lib/galaxy/tools/recommendations.py
@@ -6,12 +6,14 @@
 
 import h5py
 import numpy as np
-import requests
 import yaml
 
 from galaxy.tools.parameters import populate_state
 from galaxy.tools.parameters.workflow_utils import workflow_building_modes
-from galaxy.util import DEFAULT_SOCKET_TIMEOUT
+from galaxy.util import (
+    DEFAULT_SOCKET_TIMEOUT,
+    requests,
+)
 from galaxy.workflow.modules import module_factory
 
 log = logging.getLogger(__name__)
@@ -27,12 +29,12 @@ class ToolRecommendations:
     def __init__(self):
         self.tool_recommendation_model_path = None
         self.admin_tool_recommendations_path = None
-        self.deprecated_tools = dict()
-        self.admin_recommendations = dict()
-        self.model_data_dictionary = dict()
-        self.reverse_dictionary = dict()
-        self.all_tools = dict()
-        self.tool_weights_sorted = dict()
+        self.deprecated_tools = {}
+        self.admin_recommendations = {}
+        self.model_data_dictionary = {}
+        self.reverse_dictionary = {}
+        self.all_tools = {}
+        self.tool_weights_sorted = {}
         self.loaded_model = None
         self.compatible_tools = None
         self.standard_connections = None
@@ -60,7 +62,7 @@ def create_transformer_model(self, vocab_size):
 
         class TransformerBlock(Layer):
             def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
-                super(TransformerBlock, self).__init__()
+                super().__init__()
                 self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate)
                 self.ffn = Sequential([Dense(ff_dim, activation="relu"), Dense(embed_dim)])
                 self.layernorm1 = LayerNormalization(epsilon=1e-6)
@@ -80,7 +82,7 @@ def call(self, inputs, training):
 
         class TokenAndPositionEmbedding(Layer):
             def __init__(self, maxlen, vocab_size, embed_dim):
-                super(TokenAndPositionEmbedding, self).__init__()
+                super().__init__()
                 self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True)
                 self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim, mask_zero=True)
@@ -117,7 +119,7 @@ def get_predictions(self, trans, tool_sequence, remote_model_url):
         """
         Compute tool predictions
         """
-        recommended_tools = dict()
+        recommended_tools = {}
         self.__collect_admin_preferences(trans.app.config.admin_tool_recommendations_path)
         if self.model_ok is None:
             self.__set_model(trans, remote_model_url)
@@ -134,7 +136,7 @@ def __set_model(self, trans, remote_model_url):
         self.reverse_dictionary = json.loads(model_file["reverse_dict"][()].decode("utf-8"))
         self.loaded_model = self.create_transformer_model(len(self.reverse_dictionary) + 1)
         self.loaded_model.load_weights(self.tool_recommendation_model_path)
-        self.model_data_dictionary = dict((v, k) for k, v in self.reverse_dictionary.items())
+        self.model_data_dictionary = {v: k for k, v in self.reverse_dictionary.items()}
         # set the list of compatible tools
         self.compatible_tools = json.loads(model_file["compatible_tools"][()].decode("utf-8"))
         tool_weights = json.loads(model_file["class_weights"][()].decode("utf-8"))
@@ -197,8 +199,8 @@ def __get_tool_extensions(self, trans, tool_id):
         module.recover_state(module_state)
         inputs = module.get_all_inputs(connectable_only=True)
         outputs = module.get_all_outputs()
-        input_extensions = list()
-        output_extensions = list()
+        input_extensions = []
+        output_extensions = []
         for i_ext in inputs:
             input_extensions.extend(i_ext["extensions"])
         for o_ext in outputs:
@@ -210,14 +212,6 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score
         Filter tool predictions based on datatype compatibility and tool connections.
         Add admin preferences to recommendations.
         """
-        last_compatible_tools = []
-        if last_tool_name in self.model_data_dictionary:
-            last_tool_name_id = self.model_data_dictionary[last_tool_name]
-            if last_tool_name_id in self.compatible_tools:
-                last_compatible_tools = [
-                    self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id]
-                ]
-
         prediction_data["is_deprecated"] = False
         # get the list of datatype extensions of the last tool of the tool sequence
         _, last_output_extensions = self.__get_tool_extensions(trans, self.all_tools[last_tool_name][0])
@@ -225,19 +219,20 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score
         t_ids_scores = zip(tool_ids, tool_scores)
         # form the payload of the predicted tools to be shown
         for child, score in t_ids_scores:
-            c_dict = dict()
+            c_dict = {}
             for t_id in self.all_tools:
                 # select the name and tool id if it is installed in Galaxy
                 if (
                     t_id == child
                     and score >= 0.0
-                    and t_id in last_compatible_tools
                     and child not in self.deprecated_tools
+                    and child != last_tool_name
                 ):
                     full_tool_id = self.all_tools[t_id][0]
                     pred_input_extensions, _ = self.__get_tool_extensions(trans, full_tool_id)
                     c_dict["name"] = self.all_tools[t_id][1]
                     c_dict["tool_id"] = full_tool_id
+                    c_dict["tool_score"] = score
                     c_dict["i_extensions"] = list(set(pred_input_extensions))
                     prediction_data["children"].append(c_dict)
                     break
@@ -267,25 +262,44 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score
                 break
         return prediction_data
 
-    def __get_predicted_tools(self, base_tools, predictions, topk):
+    def __get_predicted_tools(self, pub_tools, predictions, last_tool_name, topk):
         """
         Get predicted tools. If predicted tools are less in number, combine them with published tools
         """
-        t_intersect = list(set(predictions).intersection(set(base_tools)))
-        t_diff = list(set(predictions).difference(set(base_tools)))
+        last_compatible_tools = []
+        if last_tool_name in self.model_data_dictionary:
+            last_tool_name_id = self.model_data_dictionary[last_tool_name]
+            if last_tool_name_id in self.compatible_tools:
+                last_compatible_tools = [
+                    self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id]
+                ]
+        t_intersect = list(set(predictions).intersection(set(pub_tools)))
+        t_diff = list(set(predictions).difference(set(pub_tools)))
         t_intersect, u_intersect = self.__sort_by_usage(
             t_intersect, self.tool_weights_sorted, self.model_data_dictionary
         )
         t_diff, u_diff = self.__sort_by_usage(t_diff, self.tool_weights_sorted, self.model_data_dictionary)
-        t_intersect.extend(t_diff)
-        u_intersect.extend(u_diff)
+        t_intersect_compat = list(set(last_compatible_tools).intersection(set(t_diff)))
+        # filter against rare bad predictions for any tool
+        if len(t_intersect_compat) > 0:
+            t_compat, u_compat = self.__sort_by_usage(
+                t_intersect_compat, self.tool_weights_sorted, self.model_data_dictionary
+            )
+        else:
+            t_compat, u_compat = self.__sort_by_usage(
+                last_compatible_tools, self.tool_weights_sorted, self.model_data_dictionary
+            )
+        t_intersect.extend(t_compat)
+        u_intersect.extend(u_compat)
+        t_intersect = t_intersect[:topk]
+        u_intersect = u_intersect[:topk]
         return t_intersect, u_intersect
 
     def __sort_by_usage(self, t_list, class_weights, d_dict):
         """
         Sort predictions by usage/class weights
         """
-        tool_dict = dict()
+        tool_dict = {}
         for tool in t_list:
             t_id = d_dict[tool]
             tool_dict[tool] = class_weights[int(t_id)]
@@ -296,18 +310,26 @@ def __separate_predictions(self, base_tools, predictions, last_tool_name, weight
         """
         Get predictions from published and normal workflows
         """
-        last_base_tools = list()
+        last_base_tools = []
+        weight_values = list(self.tool_weights_sorted.values())
+        wt_predictions = predictions * weight_values
         prediction_pos = np.argsort(predictions, axis=-1)
-        topk_prediction_pos = prediction_pos[-topk:]
+        wt_prediction_pos = np.argsort(wt_predictions, axis=-1)
+        topk_prediction_pos = list(prediction_pos[-topk:])
+        wt_topk_prediction_pos = list(wt_prediction_pos[-topk:])
         # get tool ids
+        wt_pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in wt_topk_prediction_pos]
         pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos]
+        # exclude same tool as the last tool
+        pred_tool_names.extend(wt_pred_tool_names)
+        pred_tool_names = [item for item in pred_tool_names if item != last_tool_name]
         if last_tool_name in base_tools:
             last_base_tools = base_tools[last_tool_name]
             if type(last_base_tools).__name__ == "str":
                 # get published or compatible tools for the last tool in a sequence of tools
                 last_base_tools = last_base_tools.split(",")
         # get predicted tools
-        sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, topk)
+        sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, last_tool_name, topk)
         return sorted_c_t, sorted_c_v
 
     def __compute_tool_prediction(self, trans, tool_sequence):
@@ -317,10 +339,10 @@
         Return an empty payload with just the tool sequence if anything goes wrong within the try block
         """
         topk = trans.app.config.topk_recommendations
-        prediction_data = dict()
+        prediction_data = {}
         tool_sequence = tool_sequence.split(",")[::-1]
         prediction_data["name"] = ",".join(tool_sequence)
-        prediction_data["children"] = list()
+        prediction_data["children"] = []
         last_tool_name = tool_sequence[-1]
         # do prediction only if the last is present in the collections of tools
         if last_tool_name in self.model_data_dictionary:
@@ -353,8 +375,5 @@
             pub_t, pub_v = self.__separate_predictions(
                 self.standard_connections, prediction, last_tool_name, weight_values, topk
             )
-            # remove duplicates if any
-            pub_t = list(dict.fromkeys(pub_t))
-            pub_v = list(dict.fromkeys(pub_v))
             prediction_data = self.__filter_tool_predictions(trans, prediction_data, pub_t, pub_v, last_tool_name)
         return prediction_data
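
For reviewers, a minimal standalone sketch of the usage-weighted top-k selection that `__separate_predictions` now performs. The arrays, the toy `reverse_dictionary`, and `tool_1` as the last tool are all made-up illustration values, not part of the patch:

```python
import numpy as np

# Toy model output: one probability per tool index (not real model scores).
predictions = np.array([0.05, 0.30, 0.20, 0.25, 0.10, 0.10])
# Toy usage weights, standing in for tool_weights_sorted.values().
weight_values = np.array([1.0, 0.5, 2.0, 1.0, 3.0, 0.2])
# Hypothetical index-to-name mapping, like self.reverse_dictionary.
reverse_dictionary = {str(i): f"tool_{i}" for i in range(6)}

topk = 3
wt_predictions = predictions * weight_values  # usage-weighted scores

# Top-k positions from both the raw and the weighted rankings,
# mirroring the two argsort/slice passes in __separate_predictions.
topk_pos = list(np.argsort(predictions, axis=-1)[-topk:])
wt_topk_pos = list(np.argsort(wt_predictions, axis=-1)[-topk:])

pred_tool_names = [reverse_dictionary[str(p)] for p in topk_pos]
pred_tool_names.extend(reverse_dictionary[str(p)] for p in wt_topk_pos)

last_tool_name = "tool_1"
# exclude the last tool of the sequence from its own recommendations
pred_tool_names = [name for name in pred_tool_names if name != last_tool_name]
print(pred_tool_names)  # e.g. ['tool_2', 'tool_3', 'tool_3', 'tool_4', 'tool_2']
```

Merging the raw and weighted rankings lets heavily used tools surface even when their raw model score is middling; `__sort_by_usage` and the new `[:topk]` slice in `__get_predicted_tools` then impose the final ordering and cut-off.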