From 997defa95044a19e52481928fa7b4b17b0412595 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Sep 2024 22:56:56 +0800 Subject: [PATCH] feat(workflow): Support sub condition in if else node. --- api/core/file/__init__.py | 4 +- api/core/file/enums.py | 13 + api/core/workflow/utils/condition/entities.py | 60 +++-- .../workflow/utils/condition/processor.py | 250 +++++++++++------- 4 files changed, 199 insertions(+), 128 deletions(-) diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py index 3c1ab487e2da4..c0f3e2a98d499 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -1,4 +1,4 @@ -from .enums import FileBelongsTo, FileTransferMethod, FileType +from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType from .models import ( File, FileExtraConfig, @@ -12,4 +12,6 @@ "FileBelongsTo", "File", "ImageConfig", + "FileAttribute", + "ArrayFileAttribute", ] diff --git a/api/core/file/enums.py b/api/core/file/enums.py index 9c8a629c15fc8..53785025a9a79 100644 --- a/api/core/file/enums.py +++ b/api/core/file/enums.py @@ -39,3 +39,16 @@ def value_of(value): if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") + + +class FileAttribute(str, Enum): + TYPE = "type" + SIZE = "size" + NAME = "name" + MIMETYPE = "mimetype" + TRANSFER_METHOD = "transfer_method" + URL = "url" + + +class ArrayFileAttribute(str, Enum): + LENGTH = "length" diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index b8e8b881a5595..e518786b73524 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -1,32 +1,42 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field +SupportedComparisonOperator = Literal[ + # for string or array + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", + # for number + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", +] + + +class SubCondition(BaseModel): + key: str + comparison_operator: SupportedComparisonOperator + value: Optional[str] = None -class Condition(BaseModel): - """ - Condition entity - """ +class SubVariable(BaseModel): + logical_operator: Literal["and", "or"] + conditions: list[SubCondition] = Field(default=list) + + +class Condition(BaseModel): variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - "null", - "not null", - ] + comparison_operator: SupportedComparisonOperator value: Optional[str] = None + sub_variable: SubVariable | None = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index c40f8efa8db4d..f98f2ef8fc477 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -1,100 +1,112 @@ from collections.abc import Sequence -from typing import Any, Optional +from typing import Any -from core.file.models import File +from core.file import ArrayFileAttribute, FileAttribute, file_manager +from core.helper import ssrf_proxy +from core.variables import FileSegment +from core.variables.segments import ArrayFileSegment from core.workflow.entities.variable_pool import VariablePool -from core.workflow.utils.condition.entities import Condition -from core.workflow.utils.variable_template_parser import VariableTemplateParser + +from .entities import Condition, SupportedComparisonOperator class ConditionProcessor: def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): input_conditions = [] - group_result = [] + group_results = [] - index = 0 for condition in conditions: - index += 1 - actual_value = variable_pool.get_any(condition.variable_selector) - - expected_value = None - if condition.value is not None: - variable_template_parser = VariableTemplateParser(template=condition.value) - variable_selectors = variable_template_parser.extract_variable_selectors() - if variable_selectors: - for variable_selector in variable_selectors: - value = variable_pool.get_any(variable_selector.value_selector) - expected_value = variable_template_parser.format({variable_selector.variable: value}) - - if expected_value is None: - expected_value = condition.value - else: - expected_value = condition.value - - comparison_operator = condition.comparison_operator - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": comparison_operator, - } - ) - - result = self.evaluate_condition(actual_value, comparison_operator, expected_value) - group_result.append(result) - - return input_conditions, group_result + variable = variable_pool.get(condition.variable_selector) + + if condition.sub_variable: + if not isinstance(variable, FileSegment | ArrayFileSegment): + raise ValueError("Invalid actual value type: FileSegment or ArrayFileSegment") + for sub_condition in condition.sub_variable.conditions: + sub_group_results = [] + actual_value = _get_sub_attribute(key=sub_condition.key, variable=variable) + expected_value = sub_condition.value + expected_value = variable_pool.convert_template(expected_value).text if expected_value else None + sub_result = self.evaluate_condition( + actual_value=actual_value, + operator=sub_condition.comparison_operator, + expected_value=expected_value, + ) + input_conditions.append( + { + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": sub_condition.comparison_operator, + } + ) + sub_group_results.append(sub_result) + result = ( + all(sub_group_results) + if condition.sub_variable.logical_operator == "and" + else any(sub_group_results) + ) + else: + actual_value = variable.value if variable else None + expected_value = condition.value + expected_value = variable_pool.convert_template(expected_value).text if expected_value else None + input_conditions.append( + { + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator, + } + ) + result = self.evaluate_condition( + actual_value=actual_value, + operator=condition.comparison_operator, + expected_value=expected_value, + ) + group_results.append(result) + + return input_conditions, group_results def evaluate_condition( self, - actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | File | None], - comparison_operator: str, - expected_value: Optional[str] = None, + actual_value: Any, + operator: SupportedComparisonOperator, + expected_value: str | None, ) -> bool: - """ - Evaluate condition - :param actual_value: actual value - :param expected_value: expected value - :param comparison_operator: comparison operator - - :return: bool - """ - if comparison_operator == "contains": - return self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - return self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - return self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - return self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - return self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - return self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - return self._assert_empty(actual_value) - elif comparison_operator == "not empty": - return self._assert_not_empty(actual_value) - elif comparison_operator == "=": - return self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - return self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - return self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - return self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - return self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - return self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - return self._assert_null(actual_value) - elif comparison_operator == "not null": - return self._assert_not_null(actual_value) - else: - raise ValueError(f"Invalid comparison operator: {comparison_operator}") - - def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + match operator: + case "contains": + return self._assert_contains(actual_value=actual_value, expected_value=expected_value) + case "not contains": + return self._assert_not_contains(actual_value=actual_value, expected_value=expected_value) + case "start with": + return self._assert_start_with(actual_value=actual_value, expected_value=expected_value) + case "end with": + return self._assert_end_with(actual_value=actual_value, expected_value=expected_value) + case "is": + return self._assert_is(actual_value=actual_value, expected_value=expected_value) + case "is not": + return self._assert_is_not(actual_value=actual_value, expected_value=expected_value) + case "empty": + return self._assert_empty(actual_value=actual_value) + case "not empty": + return self._assert_not_empty(actual_value=actual_value) + case "=": + return self._assert_equal(actual_value=actual_value, expected_value=expected_value) + case "≠": + return self._assert_not_equal(actual_value=actual_value, expected_value=expected_value) + case ">": + return self._assert_greater_than(actual_value=actual_value, expected_value=expected_value) + case "<": + return self._assert_less_than(actual_value=actual_value, expected_value=expected_value) + case "≥": + return self._assert_greater_than_or_equal(actual_value=actual_value, expected_value=expected_value) + case "≤": + return self._assert_less_than_or_equal(actual_value=actual_value, expected_value=expected_value) + case "null": + return self._assert_null(actual_value=actual_value) + case "not null": + return self._assert_not_null(actual_value=actual_value) + case _: + raise ValueError(f"Unsupported operator: {operator}") + + def _assert_contains(self, actual_value: Any, expected_value: Any) -> bool: """ Assert contains :param actual_value: actual value @@ -111,7 +123,7 @@ def _assert_contains(self, actual_value: Optional[str | list], expected_value: s return False return True - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + def _assert_not_contains(self, actual_value: Any, expected_value: Any) -> bool: """ Assert not contains :param actual_value: actual value @@ -128,7 +140,7 @@ def _assert_not_contains(self, actual_value: Optional[str | list], expected_valu return False return True - def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + def _assert_start_with(self, actual_value: Any, expected_value: Any) -> bool: """ Assert start with :param actual_value: actual value @@ -145,7 +157,7 @@ def _assert_start_with(self, actual_value: Optional[str], expected_value: str) - return False return True - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + def _assert_end_with(self, actual_value: Any, expected_value: Any) -> bool: """ Assert end with :param actual_value: actual value @@ -162,7 +174,7 @@ def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> return False return True - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + def _assert_is(self, actual_value: Any, expected_value: Any) -> bool: """ Assert is :param actual_value: actual value @@ -179,7 +191,7 @@ def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: return False return True - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + def _assert_is_not(self, actual_value: Any, expected_value: Any) -> bool: """ Assert is not :param actual_value: actual value @@ -196,7 +208,7 @@ def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bo return False return True - def _assert_empty(self, actual_value: Optional[str]) -> bool: + def _assert_empty(self, actual_value: Any) -> bool: """ Assert empty :param actual_value: actual value @@ -206,7 +218,7 @@ def _assert_empty(self, actual_value: Optional[str]) -> bool: return True return False - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + def _assert_not_empty(self, actual_value: Any) -> bool: """ Assert not empty :param actual_value: actual value @@ -216,7 +228,7 @@ def _assert_not_empty(self, actual_value: Optional[str]) -> bool: return True return False - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + def _assert_equal(self, actual_value: Any, expected_value: Any) -> bool: """ Assert equal :param actual_value: actual value @@ -238,7 +250,7 @@ def _assert_equal(self, actual_value: Optional[int | float], expected_value: str return False return True - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + def _assert_not_equal(self, actual_value: Any, expected_value: Any) -> bool: """ Assert not equal :param actual_value: actual value @@ -260,7 +272,7 @@ def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: return False return True - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + def _assert_greater_than(self, actual_value: Any, expected_value: Any) -> bool: """ Assert greater than :param actual_value: actual value @@ -282,7 +294,7 @@ def _assert_greater_than(self, actual_value: Optional[int | float], expected_val return False return True - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + def _assert_less_than(self, actual_value: Any, expected_value: Any) -> bool: """ Assert less than :param actual_value: actual value @@ -304,9 +316,7 @@ def _assert_less_than(self, actual_value: Optional[int | float], expected_value: return False return True - def _assert_greater_than_or_equal( - self, actual_value: Optional[int | float], expected_value: str | int | float - ) -> bool: + def _assert_greater_than_or_equal(self, actual_value: Any, expected_value: Any) -> bool: """ Assert greater than or equal :param actual_value: actual value @@ -328,9 +338,7 @@ def _assert_greater_than_or_equal( return False return True - def _assert_less_than_or_equal( - self, actual_value: Optional[int | float], expected_value: str | int | float - ) -> bool: + def _assert_less_than_or_equal(self, actual_value: Any, expected_value: Any) -> bool: """ Assert less than or equal :param actual_value: actual value @@ -352,7 +360,7 @@ def _assert_less_than_or_equal( return False return True - def _assert_null(self, actual_value: Optional[int | float]) -> bool: + def _assert_null(self, actual_value: Any) -> bool: """ Assert null :param actual_value: actual value @@ -362,7 +370,7 @@ def _assert_null(self, actual_value: Optional[int | float]) -> bool: return True return False - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + def _assert_not_null(self, actual_value: Any) -> bool: """ Assert not null :param actual_value: actual value @@ -379,3 +387,41 @@ def __init__(self, message: str, conditions: list[dict], sub_condition_compare_r self.conditions = conditions self.sub_condition_compare_results = sub_condition_compare_results super().__init__(self.message) + + +def _get_sub_attribute(*, key: str, variable: FileSegment | ArrayFileSegment) -> Any: + if isinstance(variable, FileSegment): + attribute = FileAttribute(key) + match attribute: + case FileAttribute.NAME: + actual_value = variable.value.filename + case FileAttribute.SIZE: + file = variable.value + if file.related_id: + file_contnet = file_manager.download(upload_file_id=file.related_id, tenant_id=file.tenant_id) + actual_value = len(file_contnet) + elif file.url: + response = ssrf_proxy.head(url=file.url) + response.raise_for_status() + actual_value = int(response.headers.get("Content-Length", 0)) + else: + raise ValueError("Invalid file") + case FileAttribute.TYPE: + actual_value = variable.value.type + case FileAttribute.MIMETYPE: + actual_value = variable.value.mime_type + case FileAttribute.TRANSFER_METHOD: + actual_value = variable.value.transfer_method + case FileAttribute.URL: + actual_value = variable.value.url + case _: + raise ValueError(f"Invalid file attribute: {attribute}") + elif isinstance(variable, ArrayFileSegment): + attribute = ArrayFileAttribute(key) + match attribute: + case ArrayFileAttribute.LENGTH: + actual_value = len(variable.value) + case _: + raise ValueError(f"Invalid array file attribute: {attribute}") + + return actual_value