From 772b588f012f5e954cfce623abeca97dc4493493 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 10 Oct 2024 18:11:21 +0800 Subject: [PATCH] refactor(models): enhance file model and tool processing - Add description to file extension field in File model for clarity. - Default meta field to an empty dictionary in ToolInvokeMessage to prevent NoneType errors. - Extend VariableKey enum in Tool with additional categories (DOCUMENT, VIDEO, AUDIO, CUSTOM). - Refactor ToolNode to handle different response types more robustly by refining file ID extraction and managing default extensions. --- api/core/file/models.py | 2 +- api/core/tools/entities/tool_entities.py | 3 ++- api/core/tools/tool/tool.py | 4 +++ api/core/workflow/nodes/tool/tool_node.py | 30 ++++++++++++++++++----- 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/api/core/file/models.py b/api/core/file/models.py index d0ca70b33e246..c87186970c80b 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -43,7 +43,7 @@ class File(BaseModel): remote_url: Optional[str] = None # remote url related_id: Optional[str] = None filename: Optional[str] = None - extension: Optional[str] = None + extension: Optional[str] = Field(default=None, description="File extension, should contains dot") mime_type: Optional[str] = None size: int = -1 _extra_config: FileExtraConfig | None = None diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 53b92f63e5d8b..9a31e673d3052 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -111,7 +111,8 @@ class MessageType(Enum): plain text, image url or link url """ message: str | bytes | dict | None = None - meta: dict[str, Any] | None = None + # TODO: Use a BaseModel for meta + meta: dict[str, Any] = Field(default_factory=dict) save_as: str = "" diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 0b0be191a3a9a..6cb6e18b6d4e8 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -64,6 +64,10 @@ def __init__(self, **data: Any): class VariableKey(str, Enum): IMAGE = "image" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" + CUSTOM = "custom" def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 72ec2d9bb9f33..bb7278b70edc3 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -165,17 +165,16 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) Extract tool response binary """ result = [] - for response in tool_response: if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: url = response.message ext = path.splitext(url)[1] mimetype = response.meta.get("mime_type", "image/jpeg") - filename = response.save_as or url.split("/")[-1] + tool_file_id = response.save_as or url.split("/")[-1] transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) # get tool file id - tool_file_id = url.split("/")[-1].split(".")[0] + tool_file_id = str(url).split("/")[-1].split(".")[0] result.append( File( tenant_id=self.tenant_id, @@ -183,14 +182,14 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) transfer_method=transfer_method, remote_url=url, related_id=tool_file_id, - filename=filename, + filename=tool_file_id, extension=ext, mime_type=mimetype, ) ) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - tool_file_id = response.message.split("/")[-1].split(".")[0] + tool_file_id = str(response.message).split("/")[-1].split(".")[0] result.append( File( tenant_id=self.tenant_id, @@ -203,7 +202,26 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) ) ) elif response.type == ToolInvokeMessage.MessageType.LINK: - pass # TODO: + url = str(response.message) + transfer_method = FileTransferMethod.TOOL_FILE + mimetype = response.meta.get("mime_type", "application/octet-stream") + tool_file_id = url.split("/")[-1].split(".")[0] + if "." in url: + extension = "." + url.split("/")[-1].split(".")[1] + else: + extension = ".bin" + file = File( + tenant_id=self.tenant_id, + type=FileType(response.save_as), + transfer_method=transfer_method, + remote_url=url, + filename=tool_file_id, + related_id=tool_file_id, + extension=extension, + mime_type=mimetype, + ) + result.append(file) + elif response.type == ToolInvokeMessage.MessageType.FILE: assert response.meta is not None result.append(response.meta["file"])