Skip to content

Commit

Permalink
Refactor execute operation to receive files in payload and adjust the…
Browse files Browse the repository at this point in the history
… corresponding service methods
  • Loading branch information
heisner-tillman committed Feb 8, 2024
1 parent e6d144c commit 81f9e04
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 54 deletions.
28 changes: 3 additions & 25 deletions lib/galaxy/webapps/galaxy/api/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
APIContentTypeRoute,
as_form,
BaseGalaxyAPIController,
depend_on_either_json_or_form_data,
depends,
DependsOnTrans,
Router,
Expand Down Expand Up @@ -121,37 +122,14 @@ async def fetch_form(
@router.post(
"/api/tools",
summary="Execute tool with a given parameter payload",
route_class_override=JsonApiRoute,
)
async def execute_json(
def execute(
self,
payload: CreateToolBody,
payload: ExecuteToolPayload = depend_on_either_json_or_form_data(ExecuteToolPayload),
trans: ProvidesHistoryContext = DependsOnTrans,
) -> ToolResponse:
return self.service.execute(trans, payload)

@router.post(
"/api/tools",
summary="Execute tool with a given parameter payload",
route_class_override=FormDataApiRoute,
)
async def execute_form(
self,
request: Request,
payload: ExecuteToolPayload = Depends(CreateDataForm.as_form),
files: Optional[List[UploadFile]] = None,
trans: ProvidesHistoryContext = DependsOnTrans,
) -> ToolResponse:
files2: List[StarletteUploadFile] = cast(List[StarletteUploadFile], files or [])

# FastAPI's UploadFile is a very light wrapper around starlette's UploadFile
if not files2:
data = await request.form()
for value in data.values():
if isinstance(value, StarletteUploadFile):
files2.append(value)
return self.service.execute(trans, payload, files2)


class ToolsController(BaseGalaxyAPIController, UsesVisualizationMixin):
"""
Expand Down
43 changes: 14 additions & 29 deletions lib/galaxy/webapps/galaxy/services/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import re
import shutil
import tempfile
from json import dumps
Expand Down Expand Up @@ -78,30 +77,13 @@ def create_temp_file(self, trans, files) -> Dict[str, FilesPayload]:
)
return files_payload

def create_temp_file_execute(self, trans, files) -> Dict[str, FilesPayload]:
# TODO - could access headers from request maybe better to do that
# header_content = request.headers["content-type"]
files_payload = {}
for i, upload_file in enumerate(files):
with tempfile.NamedTemporaryFile(
dir=trans.app.config.new_file_path, prefix="upload_file_data_", delete=False
) as dest:
shutil.copyfileobj(upload_file.file, dest) # type: ignore[misc] # https://github.com/python/mypy/issues/15031
upload_file.file.close()

# try to grab the name from the header
try:
header_items = upload_file.headers.items()[0]
name_search = re.search(r'name="([^"]+)"', header_items[1])
if name_search:
name = name_search.group(1)
else:
raise Exception("No name found in header")
# if we can't find the name in the header use the index
except Exception:
name = f"files_{i}|file_data"
files_payload[name] = FilesPayload(filename=upload_file.filename, local_filename=dest.name)
return files_payload
def create_temp_file_execute(self, trans, upload_file) -> FilesPayload:
with tempfile.NamedTemporaryFile(
dir=trans.app.config.new_file_path, prefix="upload_file_data_", delete=False
) as dest:
shutil.copyfileobj(upload_file.file, dest) # type: ignore[misc] # https://github.com/python/mypy/issues/15031
upload_file.file.close()
return FilesPayload(filename=upload_file.filename, local_filename=dest.name)

def create_fetch(
self,
Expand Down Expand Up @@ -142,7 +124,6 @@ def execute(
self,
trans: ProvidesHistoryContext,
payload: ExecuteToolPayload,
files: Optional[List[UploadFile]] = None,
) -> ToolResponse:
tool_id = payload.tool_id
tool_uuid = payload.tool_uuid
Expand All @@ -152,10 +133,14 @@ def execute(
)
if tool_id is None and tool_uuid is None:
raise HTTPException(status_code=400, detail="Must specify a valid tool_id to use this endpoint.")

create_payload = payload.model_dump(exclude_unset=True)
if files:
create_payload.update(self.create_temp_file_execute(trans, files))

# create temporary files from the uploaded files
files = {}
for key in list(create_payload.keys()):
if key.startswith("files_") or key.startswith("__files_"):
files[key] = self.create_temp_file_execute(trans, create_payload.pop(key))
create_payload.update(files)
return self._create(trans, create_payload)

def _create(self, trans: ProvidesHistoryContext, payload) -> ToolResponse:
Expand Down

0 comments on commit 81f9e04

Please sign in to comment.