Skip to content

Commit

Permalink
Refactor get_tool_predictions operation
Browse files Browse the repository at this point in the history
  • Loading branch information
heisner-tillman committed Feb 20, 2024
1 parent 310fc85 commit 6a3da43
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions lib/galaxy/webapps/galaxy/api/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@
SharingStatus,
WorkflowSortByEnum,
)
from galaxy.schema.workflows import InvokeWorkflowPayload
from galaxy.schema.workflows import (
GetToolPredictionsPayload,
InvokeWorkflowPayload,
)
from galaxy.structured_app import StructuredApp
from galaxy.tool_shed.galaxy_install.install_manager import InstallRepositoryManager
from galaxy.tools import recommendations
Expand Down Expand Up @@ -587,29 +590,6 @@ def build_module(self, trans: GalaxyWebTransaction, payload=None):
step_dict["tool_version"] = module.get_version()
return step_dict

@expose_api
def get_tool_predictions(self, trans: ProvidesUserContext, payload, **kwd):
"""
POST /api/workflows/get_tool_predictions
Fetch predicted tools for a workflow
:type payload: dict
:param payload:
a dictionary containing two parameters
'tool_sequence' - comma separated sequence of tool ids
'remote_model_url' - (optional) path to the deep learning model
"""
remote_model_url = payload.get("remote_model_url", trans.app.config.tool_recommendation_model_path)
tool_sequence = payload.get("tool_sequence", "")
if "tool_sequence" not in payload or remote_model_url is None:
return
tool_sequence, recommended_tools = self.tool_recommendations.get_predictions(
trans, tool_sequence, remote_model_url
)
return {"current_tool": tool_sequence, "predicted_data": recommended_tools}

#
# -- Helper methods --
#
Expand Down Expand Up @@ -909,6 +889,15 @@ def __get_stored_workflow(self, trans, workflow_id, **kwd):
),
]

GetToolPredictionsBody = Annotated[
GetToolPredictionsPayload,
Body(
default=...,
title="Get tool predictions",
description="The values to get tool predictions for a workflow.",
),
]


@router.cbv
class FastAPIWorkflows:
Expand Down Expand Up @@ -955,6 +944,17 @@ def index(
response.headers["total_matches"] = str(total_matches)
return workflows

@router.post(
"/api/workflows/get_tool_predictions",
summary="Fetch predicted tools for a workflow",
)
def get_tool_predictions(
self,
payload: GetToolPredictionsBody,
trans: ProvidesUserContext = DependsOnTrans,
):
return self.service.get_tool_predictions(trans, payload.model_dump(exclude_unset=True))

@router.get(
"/api/workflows/{workflow_id}/sharing",
summary="Get the current sharing status of the given item.",
Expand Down

0 comments on commit 6a3da43

Please sign in to comment.