From 31ccccaf087119379ef9584487bed645b15caa54 Mon Sep 17 00:00:00 2001 From: heisner-tillman Date: Mon, 19 Feb 2024 21:31:57 +0100 Subject: [PATCH] Create service method for get_tool_predictions operation from WorkflowAPI --- .../webapps/galaxy/services/workflows.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/lib/galaxy/webapps/galaxy/services/workflows.py b/lib/galaxy/webapps/galaxy/services/workflows.py index 798bda20a576..5db50e7292a1 100644 --- a/lib/galaxy/webapps/galaxy/services/workflows.py +++ b/lib/galaxy/webapps/galaxy/services/workflows.py @@ -25,7 +25,8 @@ InvocationsStateCounts, WorkflowIndexQueryPayload, ) -from galaxy.schema.workflows import InvokeWorkflowPayload +from galaxy.schema.workflows import InvokeWorkflowPayload # GetToolPredictionsPayload, +from galaxy.tools import recommendations from galaxy.util.tool_shed.tool_shed_registry import Registry from galaxy.webapps.galaxy.services.base import ServiceBase from galaxy.webapps.galaxy.services.sharable import ShareableService @@ -53,6 +54,21 @@ def __init__( self._serializer = serializer self.shareable_service = ShareableService(workflows_manager, serializer, notification_manager) self._tool_shed_registry = tool_shed_registry + self.tool_recommendations = recommendations.ToolRecommendations() + + def get_tool_predictions( + self, + trans: ProvidesUserContext, + payload, + ): + remote_model_url = payload.get("remote_model_url", trans.app.config.tool_recommendation_model_path) + tool_sequence = payload.get("tool_sequence", "") + if "tool_sequence" not in payload or remote_model_url is None: + return {} + tool_sequence, recommended_tools = self.tool_recommendations.get_predictions( + trans, tool_sequence, remote_model_url + ) + return {"current_tool": tool_sequence, "predicted_data": recommended_tools} def index( self,