From 81f9e0417c18728e63e626644e62019d3146a47a Mon Sep 17 00:00:00 2001 From: heisner-tillman Date: Thu, 8 Feb 2024 11:18:01 +0100 Subject: [PATCH] Refactor execute operation to receive files in payload and adjust the corresponding service methods --- lib/galaxy/webapps/galaxy/api/tools.py | 28 ++------------ lib/galaxy/webapps/galaxy/services/tools.py | 43 +++++++-------------- 2 files changed, 17 insertions(+), 54 deletions(-) diff --git a/lib/galaxy/webapps/galaxy/api/tools.py b/lib/galaxy/webapps/galaxy/api/tools.py index 455a68809f31..32372be3631f 100644 --- a/lib/galaxy/webapps/galaxy/api/tools.py +++ b/lib/galaxy/webapps/galaxy/api/tools.py @@ -51,6 +51,7 @@ APIContentTypeRoute, as_form, BaseGalaxyAPIController, + depend_on_either_json_or_form_data, depends, DependsOnTrans, Router, @@ -121,37 +122,14 @@ async def fetch_form( @router.post( "/api/tools", summary="Execute tool with a given parameter payload", - route_class_override=JsonApiRoute, ) - async def execute_json( + def execute( self, - payload: CreateToolBody, + payload: ExecuteToolPayload = depend_on_either_json_or_form_data(ExecuteToolPayload), trans: ProvidesHistoryContext = DependsOnTrans, ) -> ToolResponse: return self.service.execute(trans, payload) - @router.post( - "/api/tools", - summary="Execute tool with a given parameter payload", - route_class_override=FormDataApiRoute, - ) - async def execute_form( - self, - request: Request, - payload: ExecuteToolPayload = Depends(CreateDataForm.as_form), - files: Optional[List[UploadFile]] = None, - trans: ProvidesHistoryContext = DependsOnTrans, - ) -> ToolResponse: - files2: List[StarletteUploadFile] = cast(List[StarletteUploadFile], files or []) - - # FastAPI's UploadFile is a very light wrapper around starlette's UploadFile - if not files2: - data = await request.form() - for value in data.values(): - if isinstance(value, StarletteUploadFile): - files2.append(value) - return self.service.execute(trans, payload, files2) - class ToolsController(BaseGalaxyAPIController, UsesVisualizationMixin): """ diff --git a/lib/galaxy/webapps/galaxy/services/tools.py b/lib/galaxy/webapps/galaxy/services/tools.py index 595dbc09da05..9053cccce56a 100644 --- a/lib/galaxy/webapps/galaxy/services/tools.py +++ b/lib/galaxy/webapps/galaxy/services/tools.py @@ -1,5 +1,4 @@ import logging -import re import shutil import tempfile from json import dumps @@ -78,30 +77,13 @@ def create_temp_file(self, trans, files) -> Dict[str, FilesPayload]: ) return files_payload - def create_temp_file_execute(self, trans, files) -> Dict[str, FilesPayload]: - # TODO - could access headers from request maybe better to do that - # header_content = request.headers["content-type"] - files_payload = {} - for i, upload_file in enumerate(files): - with tempfile.NamedTemporaryFile( - dir=trans.app.config.new_file_path, prefix="upload_file_data_", delete=False - ) as dest: - shutil.copyfileobj(upload_file.file, dest) # type: ignore[misc] # https://github.com/python/mypy/issues/15031 - upload_file.file.close() - - # try to grab the name from the header - try: - header_items = upload_file.headers.items()[0] - name_search = re.search(r'name="([^"]+)"', header_items[1]) - if name_search: - name = name_search.group(1) - else: - raise Exception("No name found in header") - # if we can't find the name in the header use the index - except Exception: - name = f"files_{i}|file_data" - files_payload[name] = FilesPayload(filename=upload_file.filename, local_filename=dest.name) - return files_payload + def create_temp_file_execute(self, trans, upload_file) -> FilesPayload: + with tempfile.NamedTemporaryFile( + dir=trans.app.config.new_file_path, prefix="upload_file_data_", delete=False + ) as dest: + shutil.copyfileobj(upload_file.file, dest) # type: ignore[misc] # https://github.com/python/mypy/issues/15031 + upload_file.file.close() + return FilesPayload(filename=upload_file.filename, local_filename=dest.name) def create_fetch( self, @@ -142,7 +124,6 @@ def execute( self, trans: ProvidesHistoryContext, payload: ExecuteToolPayload, - files: Optional[List[UploadFile]] = None, ) -> ToolResponse: tool_id = payload.tool_id tool_uuid = payload.tool_uuid @@ -152,10 +133,14 @@ def execute( ) if tool_id is None and tool_uuid is None: raise HTTPException(status_code=400, detail="Must specify a valid tool_id to use this endpoint.") - create_payload = payload.model_dump(exclude_unset=True) - if files: - create_payload.update(self.create_temp_file_execute(trans, files)) + + # create temporary files from the uploaded files + files = {} + for key in list(create_payload.keys()): + if key.startswith("files_") or key.startswith("__files_"): + files[key] = self.create_temp_file_execute(trans, create_payload.pop(key)) + create_payload.update(files) return self._create(trans, create_payload) def _create(self, trans: ProvidesHistoryContext, payload) -> ToolResponse: