diff --git a/lib/galaxy/webapps/galaxy/services/workflows.py b/lib/galaxy/webapps/galaxy/services/workflows.py index 798bda20a576..5db50e7292a1 100644 --- a/lib/galaxy/webapps/galaxy/services/workflows.py +++ b/lib/galaxy/webapps/galaxy/services/workflows.py @@ -25,7 +25,8 @@ InvocationsStateCounts, WorkflowIndexQueryPayload, ) -from galaxy.schema.workflows import InvokeWorkflowPayload +from galaxy.schema.workflows import InvokeWorkflowPayload # GetToolPredictionsPayload, +from galaxy.tools import recommendations from galaxy.util.tool_shed.tool_shed_registry import Registry from galaxy.webapps.galaxy.services.base import ServiceBase from galaxy.webapps.galaxy.services.sharable import ShareableService @@ -53,6 +54,21 @@ def __init__( self._serializer = serializer self.shareable_service = ShareableService(workflows_manager, serializer, notification_manager) self._tool_shed_registry = tool_shed_registry + self.tool_recommendations = recommendations.ToolRecommendations() + + def get_tool_predictions( + self, + trans: ProvidesUserContext, + payload, + ): + remote_model_url = payload.get("remote_model_url", trans.app.config.tool_recommendation_model_path) + tool_sequence = payload.get("tool_sequence", "") + if "tool_sequence" not in payload or remote_model_url is None: + return {} + tool_sequence, recommended_tools = self.tool_recommendations.get_predictions( + trans, tool_sequence, remote_model_url + ) + return {"current_tool": tool_sequence, "predicted_data": recommended_tools} def index( self,