diff --git a/lib/galaxy/webapps/galaxy/api/workflows.py b/lib/galaxy/webapps/galaxy/api/workflows.py index f5d506599cf8..6c3aad52eb7b 100644 --- a/lib/galaxy/webapps/galaxy/api/workflows.py +++ b/lib/galaxy/webapps/galaxy/api/workflows.py @@ -80,7 +80,10 @@ SharingStatus, WorkflowSortByEnum, ) -from galaxy.schema.workflows import InvokeWorkflowPayload +from galaxy.schema.workflows import ( + GetToolPredictionsPayload, + InvokeWorkflowPayload, +) from galaxy.structured_app import StructuredApp from galaxy.tool_shed.galaxy_install.install_manager import InstallRepositoryManager from galaxy.tools import recommendations @@ -587,29 +590,6 @@ def build_module(self, trans: GalaxyWebTransaction, payload=None): step_dict["tool_version"] = module.get_version() return step_dict - @expose_api - def get_tool_predictions(self, trans: ProvidesUserContext, payload, **kwd): - """ - POST /api/workflows/get_tool_predictions - - Fetch predicted tools for a workflow - - :type payload: dict - :param payload: - - a dictionary containing two parameters - 'tool_sequence' - comma separated sequence of tool ids - 'remote_model_url' - (optional) path to the deep learning model - """ - remote_model_url = payload.get("remote_model_url", trans.app.config.tool_recommendation_model_path) - tool_sequence = payload.get("tool_sequence", "") - if "tool_sequence" not in payload or remote_model_url is None: - return - tool_sequence, recommended_tools = self.tool_recommendations.get_predictions( - trans, tool_sequence, remote_model_url - ) - return {"current_tool": tool_sequence, "predicted_data": recommended_tools} - # # -- Helper methods -- # @@ -909,6 +889,15 @@ def __get_stored_workflow(self, trans, workflow_id, **kwd): ), ] +GetToolPredictionsBody = Annotated[ + GetToolPredictionsPayload, + Body( + default=..., + title="Get tool predictions", + description="The values to get tool predictions for a workflow.", + ), +] + @router.cbv class FastAPIWorkflows: @@ -955,6 +944,17 @@ def index( response.headers["total_matches"] = str(total_matches) return workflows + @router.post( + "/api/workflows/get_tool_predictions", + summary="Fetch predicted tools for a workflow", + ) + def get_tool_predictions( + self, + payload: GetToolPredictionsBody, + trans: ProvidesUserContext = DependsOnTrans, + ): + return self.service.get_tool_predictions(trans, payload.model_dump(exclude_unset=True)) + @router.get( "/api/workflows/{workflow_id}/sharing", summary="Get the current sharing status of the given item.",