Skip to content

Commit

Permalink
feat(workflow): Add ListFilterNode.
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 committed Sep 22, 2024
1 parent 997defa commit 776adfa
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 2 deletions.
2 changes: 1 addition & 1 deletion api/core/file/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class FileAttribute(str, Enum):
TYPE = "type"
SIZE = "size"
NAME = "name"
MIMETYPE = "mimetype"
MIME_TYPE = "mime_type"
TRANSFER_METHOD = "transfer_method"
URL = "url"

Expand Down
3 changes: 3 additions & 0 deletions api/core/workflow/nodes/list_filter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .node import ListFilterNode

__all__ = ["ListFilterNode"]
37 changes: 37 additions & 0 deletions api/core/workflow/nodes/list_filter/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from collections.abc import Sequence
from typing import Literal

from pydantic import Field

from core.workflow.entities.base_node_data_entities import BaseNodeData

# Closed set of condition names accepted by ``ListFilterNodeData.condition``.
# The string conditions are applied to string arrays and to string-valued
# file attributes; the number conditions are applied to number arrays and to
# the file ``size`` attribute.
_Condition = Literal[
    # string conditions
    "contains",
    "startswith",
    "endswith",
    "is",
    "in",
    "empty",
    "not contains",
    "not is",
    "not in",
    "not empty",
    # number conditions
    "=",
    "!=",
    "<",
    ">",
    "<=",
    ">=",
]


class ListFilterNodeData(BaseNodeData):
    """Configuration for the list-filter workflow node."""

    # Selector used to look up the array variable in the variable pool.
    variable_selector: Sequence[str] = Field(default_factory=list)
    # Attribute name to sort by — presumably only meaningful for file arrays; verify against the node.
    order_by: str = ""
    # Sort direction; ``None`` means the filtered list is not sorted.
    order: Literal["asc", "desc"] | None = None
    # Maximum number of items to keep; values greater than -1 truncate the result.
    limit: int = -1
    # File attribute to filter on; passed to the file filter for file arrays.
    key: str = ""
    # Name of the comparison to apply (required).
    condition: _Condition
    # Comparison operand; rendered through the variable pool's template conversion
    # and converted with ``float`` when the comparison is numeric.
    value: str
220 changes: 220 additions & 0 deletions api/core/workflow/nodes/list_filter/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from collections.abc import Callable
from typing import Literal, cast

from core.file import File, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from enums.workflow_nodes import NodeType
from models.workflow import WorkflowNodeExecutionStatus

from .models import ListFilterNodeData


class ListFilterNode(BaseNode):
    """Workflow node that filters, optionally sorts, and truncates an array variable.

    Supported inputs are ``ArrayStringSegment``, ``ArrayNumberSegment`` and
    ``ArrayFileSegment``. The filtered list is published under the ``result``
    key of the node outputs.
    """

    _node_data_cls = ListFilterNodeData
    _node_type = NodeType.LIST_FILTER

    def _run(self):
        """Run the filter and return a ``NodeRunResult``.

        Returns a FAILED result when the selected variable is missing or is
        not one of the supported array segment types. ``ValueError`` raised by
        the filter/order helpers (invalid key or condition, non-numeric
        operand for a numeric comparison) propagates to the caller.
        """
        node_data = cast(ListFilterNodeData, self.node_data)
        inputs = {}
        process_data = {}
        outputs = {}

        variable = self.graph_runtime_state.variable_pool.get(node_data.variable_selector)
        if variable is None:
            error_message = f"Variable not found for selector: {node_data.variable_selector}"
            return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)
        if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
            error_message = (
                f"Variable {node_data.variable_selector} is not an ArrayFileSegment, ArrayNumberSegment "
                "or ArrayStringSegment"
            )
            return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)

        # The operand may contain template references; resolve them to plain text first.
        value = self.graph_runtime_state.variable_pool.convert_template(node_data.value).text

        # Filter, then sort when an order is configured.
        if isinstance(variable, ArrayStringSegment):
            filter_func = _get_string_filter_func(condition=node_data.condition, value=value)
            result = list(filter(filter_func, variable.value))
            if node_data.order is not None:
                result = _order_string(order=node_data.order, array=result)
        elif isinstance(variable, ArrayNumberSegment):
            filter_func = _get_number_filter_func(condition=node_data.condition, value=float(value))
            result = list(filter(filter_func, variable.value))
            if node_data.order is not None:
                result = _order_number(order=node_data.order, array=result)
        elif isinstance(variable, ArrayFileSegment):
            filter_func = _get_file_filter_func(key=node_data.key, condition=node_data.condition, value=value)
            result = list(filter(filter_func, variable.value))
            if node_data.order is not None:
                # Files have no natural ordering: the configured attribute must be
                # forwarded, otherwise the helper rejects its empty default key.
                result = _order_file(order=node_data.order, order_by=node_data.order_by, array=result)

        # Slice: a limit of -1 (the default) keeps everything.
        if node_data.limit > -1:
            result = result[: node_data.limit]

        # Publish the filtered list; without this the node's work is discarded.
        outputs["result"] = result

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
        )


def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
    """Return the extractor for a numeric file attribute.

    Only ``"size"`` is supported; any other key raises ``ValueError``.
    """
    if key != "size":
        raise ValueError(f"Invalid key: {key}")
    return _get_file_size


def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
match key:
case "name":
return lambda x: x.filename or ""
case "type":
return lambda x: x.type
case "extension":
return lambda x: x.extension or ""
case "mimetype":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: x.transfer_method
case "urL":
return lambda x: x.url or ""
case _:
raise ValueError(f"Invalid key: {key}")


def _get_file_size(file: File):
    """Return the size of ``file`` in bytes.

    Files with a ``related_id`` are downloaded through ``file_manager`` and
    measured; files with only a ``url`` are probed with a HEAD request and the
    ``Content-Length`` header is used (0 when the header is absent).

    Raises ``ValueError`` when the file has neither a ``related_id`` nor a ``url``.
    """
    if file.related_id:
        # NOTE(review): the full content is fetched just to measure its length.
        payload = file_manager.download(upload_file_id=file.related_id, tenant_id=file.tenant_id)
        return len(payload)

    if not file.url:
        raise ValueError("Invalid file")

    head_response = ssrf_proxy.head(url=file.url)
    head_response.raise_for_status()
    content_length = head_response.headers.get("Content-Length", 0)
    return int(content_length)


def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
match condition:
case "contains":
return _contains(value)
case "startswith":
return _startswith(value)
case "endswith":
return _endswith(value)
case "is":
return _is(value)
case "in":
return _in(value)
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
case "not is":
return lambda x: not _is(value)(x)
case "not in":
return lambda x: not _in(value)(x)
case "not empty":
return lambda x: x != ""
case _:
raise ValueError(f"Invalid condition: {condition}")


def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
match condition:
case "=":
return _eq(value)
case "!=":
return _ne(value)
case "<":
return _lt(value)
case "<=":
return _le(value)
case ">":
return _gt(value)
case ">=":
return _ge(value)
case _:
raise ValueError(f"Invalid condition: {condition}")


def _get_file_filter_func(*, key: str, condition: str, value: str) -> Callable[[File], bool]:
    """Build a predicate over files that tests attribute ``key`` against ``value``.

    String attributes (``name``, ``type``, ``extension``, ``mime_type``,
    ``transfer_method``, ``url``) use the string conditions. ``size`` uses the
    numeric conditions, with ``value`` converted via ``float``. The legacy
    spellings ``mimetype`` and ``urL`` are tolerated.

    The comparison predicate is built once, up front, instead of once per
    file; as a consequence an invalid condition or a non-numeric ``value``
    is reported immediately rather than when the first file is tested.

    Raises:
        ValueError: for an unsupported key or condition, or when ``value``
            cannot be converted to a number for a ``size`` comparison.
    """
    if key in {"name", "type", "extension", "mime_type", "mimetype", "transfer_method", "url", "urL"}:
        extract_func = _get_file_extract_string_func(key=key)
        predicate = _get_string_filter_func(condition=condition, value=value)
    elif key == "size":
        extract_func = _get_file_extract_number_func(key=key)
        predicate = _get_number_filter_func(condition=condition, value=float(value))
    else:
        raise ValueError(f"Invalid key: {key}")
    return lambda x: predicate(extract_func(x))


def _contains(value: str):
return lambda x: value in x


def _startswith(value: str):
return lambda x: x.startswith(value)


def _endswith(value: str):
return lambda x: x.endswith(value)


def _is(value: str):
return lambda x: x is value


def _in(value: str):
return lambda x: x in value


def _eq(value: int | float):
return lambda x: x == value


def _ne(value: int | float):
return lambda x: x != value


def _lt(value: int | float):
return lambda x: x < value


def _le(value: int | float):
return lambda x: x <= value


def _gt(value: int | float):
return lambda x: x > value


def _ge(value: int | float):
return lambda x: x >= value


def _order_number(*, order: Literal["asc", "desc"], array: list[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")


def _order_string(*, order: Literal["asc", "desc"], array: list[str]):
return sorted(array, key=lambda x: x, reverse=order == "desc")


def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: list[File]):
    """Return a new list of files sorted by the attribute named ``order_by``.

    String attributes (``name``, ``type``, ``extension``, ``mime_type``,
    ``transfer_method``, ``url``) and the numeric ``size`` attribute are
    supported. The legacy spellings ``mimetype`` and ``urL`` are tolerated.
    Sorting is ascending unless ``order`` is ``"desc"``.

    Raises:
        ValueError: if ``order_by`` is not a supported attribute name
            (including the empty default).
    """
    if order_by in {"name", "type", "extension", "mime_type", "mimetype", "transfer_method", "url", "urL"}:
        extract_func = _get_file_extract_string_func(key=order_by)
    elif order_by == "size":
        extract_func = _get_file_extract_number_func(key=order_by)
    else:
        raise ValueError(f"Invalid order key: {order_by}")
    # The extractor is already a one-argument callable, so it serves as the sort key directly.
    return sorted(array, key=extract_func, reverse=order == "desc")
2 changes: 2 additions & 0 deletions api/core/workflow/nodes/node_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.list_filter import ListFilterNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
Expand Down Expand Up @@ -36,4 +37,5 @@
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.LIST_FILTER: ListFilterNode,
}
2 changes: 1 addition & 1 deletion api/core/workflow/utils/condition/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def _get_sub_attribute(*, key: str, variable: FileSegment | ArrayFileSegment) ->
raise ValueError("Invalid file")
case FileAttribute.TYPE:
actual_value = variable.value.type
case FileAttribute.MIMETYPE:
case FileAttribute.MIME_TYPE:
actual_value = variable.value.mime_type
case FileAttribute.TRANSFER_METHOD:
actual_value = variable.value.transfer_method
Expand Down
1 change: 1 addition & 0 deletions api/enums/workflow_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class NodeType(str, Enum):
PARAMETER_EXTRACTOR = "parameter-extractor"
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_FILTER = "list-filter"

@classmethod
def value_of(cls, value: str):
Expand Down

0 comments on commit 776adfa

Please sign in to comment.