From 6e651727dec5fdb25162decb48797ceefed66395 Mon Sep 17 00:00:00 2001 From: Anup Kumar Date: Mon, 3 Jun 2024 15:53:29 +0200 Subject: [PATCH] Merge changes from Main and Galaxy EU --- lib/galaxy/tools/recommendations.py | 59 +++++++++++++++++++---------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/lib/galaxy/tools/recommendations.py b/lib/galaxy/tools/recommendations.py index 00d8490a9737..c37a8876c72f 100644 --- a/lib/galaxy/tools/recommendations.py +++ b/lib/galaxy/tools/recommendations.py @@ -212,14 +212,6 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score Filter tool predictions based on datatype compatibility and tool connections. Add admin preferences to recommendations. """ - last_compatible_tools = [] - if last_tool_name in self.model_data_dictionary: - last_tool_name_id = self.model_data_dictionary[last_tool_name] - if last_tool_name_id in self.compatible_tools: - last_compatible_tools = [ - self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id] - ] - prediction_data["is_deprecated"] = False # get the list of datatype extensions of the last tool of the tool sequence _, last_output_extensions = self.__get_tool_extensions(trans, self.all_tools[last_tool_name][0]) @@ -233,13 +225,14 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score if ( t_id == child and score >= 0.0 - and t_id in last_compatible_tools and child not in self.deprecated_tools + and child != last_tool_name ): full_tool_id = self.all_tools[t_id][0] pred_input_extensions, _ = self.__get_tool_extensions(trans, full_tool_id) c_dict["name"] = self.all_tools[t_id][1] c_dict["tool_id"] = full_tool_id + c_dict["tool_score"] = score c_dict["i_extensions"] = list(set(pred_input_extensions)) prediction_data["children"].append(c_dict) break @@ -269,18 +262,39 @@ def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_score break return prediction_data - def __get_predicted_tools(self, base_tools, predictions, topk): + def __get_predicted_tools(self, pub_tools, predictions, last_tool_name, topk): """ Get predicted tools. If predicted tools are less in number, combine them with published tools """ - t_intersect = list(set(predictions).intersection(set(base_tools))) - t_diff = list(set(predictions).difference(set(base_tools))) + last_compatible_tools = list() + if last_tool_name in self.model_data_dictionary: + last_tool_name_id = self.model_data_dictionary[last_tool_name] + if last_tool_name_id in self.compatible_tools: + last_compatible_tools = [ + self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id] + ] + t_intersect = list(set(predictions).intersection(set(pub_tools))) + t_diff = list(set(predictions).difference(set(pub_tools))) t_intersect, u_intersect = self.__sort_by_usage( t_intersect, self.tool_weights_sorted, self.model_data_dictionary ) - t_diff, u_diff = self.__sort_by_usage(t_diff, self.tool_weights_sorted, self.model_data_dictionary) - t_intersect.extend(t_diff) - u_intersect.extend(u_diff) + t_diff, u_diff = self.__sort_by_usage( + t_diff, self.tool_weights_sorted, self.model_data_dictionary + ) + t_intersect_compat = list(set(last_compatible_tools).intersection(set(t_diff))) + # filter against rare bad predictions for any tool + if len(t_intersect_compat) > 0: + t_compat, u_compat = self.__sort_by_usage( + t_intersect_compat, self.tool_weights_sorted, self.model_data_dictionary + ) + else: + t_compat, u_compat = self.__sort_by_usage( + last_compatible_tools, self.tool_weights_sorted, self.model_data_dictionary + ) + t_intersect.extend(t_compat) + u_intersect.extend(u_compat) + t_intersect = t_intersect[:topk] + u_intersect = u_intersect[:topk] return t_intersect, u_intersect def __sort_by_usage(self, t_list, class_weights, d_dict): @@ -299,17 +313,25 @@ def __separate_predictions(self, base_tools, predictions, last_tool_name, weight Get predictions from published and normal workflows """ last_base_tools = [] + weight_values = list(self.tool_weights_sorted.values()) + wt_predictions = predictions * weight_values prediction_pos = np.argsort(predictions, axis=-1) - topk_prediction_pos = prediction_pos[-topk:] + wt_prediction_pos = np.argsort(wt_predictions, axis=-1) + topk_prediction_pos = list(prediction_pos[-topk:]) + wt_topk_prediction_pos = list(wt_prediction_pos[-topk:]) # get tool ids + wt_pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in wt_topk_prediction_pos] pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos] + # exclude same tool as the last tool + pred_tool_names.extend(wt_pred_tool_names) + pred_tool_names = [item for item in pred_tool_names if item != last_tool_name] if last_tool_name in base_tools: last_base_tools = base_tools[last_tool_name] if type(last_base_tools).__name__ == "str": # get published or compatible tools for the last tool in a sequence of tools last_base_tools = last_base_tools.split(",") # get predicted tools - sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, topk) + sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, last_tool_name, topk) return sorted_c_t, sorted_c_v def __compute_tool_prediction(self, trans, tool_sequence): @@ -355,8 +377,5 @@ def __compute_tool_prediction(self, trans, tool_sequence): pub_t, pub_v = self.__separate_predictions( self.standard_connections, prediction, last_tool_name, weight_values, topk ) - # remove duplicates if any - pub_t = list(dict.fromkeys(pub_t)) - pub_v = list(dict.fromkeys(pub_v)) prediction_data = self.__filter_tool_predictions(trans, prediction_data, pub_t, pub_v, last_tool_name) return prediction_data