Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: 流程支持从指定位置开始 #180

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion bamboo_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def run_pipeline(
cycle_tolerate = options.get("cycle_tolerate", False)
validator.validate_and_process_pipeline(pipeline, cycle_tolerate)

start_node_id = options.get("start_node_id", pipeline["start_event"]["id"])
# 如果起始位置不是开始节点,则需要进行额外校验
validator.validate_pipeline_start_node(pipeline, start_node_id)

self.runtime.pre_prepare_run_pipeline(
pipeline, root_pipeline_data, root_pipeline_context, subprocess_context, **options
)
Expand All @@ -127,7 +131,7 @@ def run_pipeline(
# execute from start event
self.runtime.execute(
process_id=process_id,
node_id=pipeline["start_event"]["id"],
node_id=start_node_id,
root_pipeline_id=pipeline["id"],
parent_pipeline_id=pipeline["id"],
)
Expand Down
5 changes: 5 additions & 0 deletions bamboo_engine/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
specific language governing permissions and limitations under the License.
"""


# 异常定义模块


Expand Down Expand Up @@ -38,6 +39,10 @@ class TreeInvalidException(EngineException):
pass


class StartPositionInvalidException(EngineException):
pass


class ConnectionValidateError(TreeInvalidException):
def __init__(self, failed_nodes, detail, *args):
self.failed_nodes = failed_nodes
Expand Down
6 changes: 5 additions & 1 deletion bamboo_engine/validator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@
specific language governing permissions and limitations under the License.
"""

from .api import validate_and_process_pipeline # noqa
from .api import ( # noqa
get_allowed_start_node_ids,
validate_and_process_pipeline,
validate_pipeline_start_node,
)
19 changes: 13 additions & 6 deletions bamboo_engine/validator/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@
specific language governing permissions and limitations under the License.
"""

from bamboo_engine.eri import NodeType
from bamboo_engine import exceptions
from bamboo_engine.eri import NodeType

from . import rules
from .connection import (
validate_graph_connection,
validate_graph_without_circle,
)
from .connection import validate_graph_connection, validate_graph_without_circle
from .gateway import validate_gateways, validate_stream
from .utils import format_pipeline_tree_io_to_list
from .utils import format_pipeline_tree_io_to_list, get_allowed_start_node_ids


def validate_pipeline_start_node(pipeline: dict, node_id: str):
# 当开始位置位于开始节点时,则直接返回
if node_id == pipeline["start_event"]["id"]:
return

allowed_start_node_ids = get_allowed_start_node_ids(pipeline)
if node_id not in allowed_start_node_ids:
raise exceptions.StartPositionInvalidException("this node_id is not allowed as a starting node")


def validate_and_process_pipeline(pipeline: dict, cycle_tolerate=False):
Expand Down
26 changes: 26 additions & 0 deletions bamboo_engine/validator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,29 @@ def get_nodes_dict(data):
node["target"] = [data["flows"][outgoing]["target"] for outgoing in node["outgoing"]]

return nodes


def _compute_pipeline_main_nodes(node_id, node_dict):
hanshuaikang marked this conversation as resolved.
Show resolved Hide resolved
nodes = []
node_detail = node_dict[node_id]
node_type = node_detail["type"]
if node_type in ["EmptyStartEvent", "ServiceActivity"]:
nodes.append(node_id)

if node_type in ["EmptyStartEvent", "ServiceActivity", "ExclusiveGateway", "ConvergeGateway", "SubProcess"]:
next_nodes = node_detail.get("target", [])
for next_node_id in next_nodes:
nodes += _compute_pipeline_main_nodes(next_node_id, node_dict)
elif node_type in ["ParallelGateway", "ConditionalParallelGateway"]:
next_node_id = node_detail["converge_gateway_id"]
nodes += _compute_pipeline_main_nodes(next_node_id, node_dict)

return nodes


def get_allowed_start_node_ids(pipeline_tree):
start_event_id = pipeline_tree["start_event"]["id"]
node_dict = get_nodes_dict(pipeline_tree)
# 流程的开始位置只允许出现在主干,子流程/并行网关内的节点不允许作为起始位置
allowed_start_node_ids = _compute_pipeline_main_nodes(start_event_id, node_dict)
return allowed_start_node_ids
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
import time

import pytest
from pipeline.eri.models import State
from pipeline.eri.runtime import BambooDjangoRuntime

from bamboo_engine.builder import (
EmptyEndEvent,
EmptyStartEvent,
ServiceActivity,
build_tree,
)
from bamboo_engine.engine import Engine


def test_run_pipeline_with_start_node_id():
start = EmptyStartEvent()
act_1 = ServiceActivity(component_code="callback_node")
end = EmptyEndEvent()

start.extend(act_1).extend(end)

pipeline = build_tree(start)
runtime = BambooDjangoRuntime()
engine = Engine(runtime)
engine.run_pipeline(pipeline=pipeline, root_pipeline_data={}, start_node_id=act_1.id)

time.sleep(3)

with pytest.raises(State.DoesNotExist):
# 由于直接跳过了开始节点,此时应该抛异常
runtime.get_state(start.id)

state = runtime.get_state(act_1.id)

assert state.name == "RUNNING"

engine.callback(act_1.id, state.version, {})

time.sleep(2)

state = runtime.get_state(act_1.id)

assert state.name == "FINISHED"

pipeline_state = runtime.get_state(pipeline["id"])

assert pipeline_state.name == "FINISHED"
112 changes: 112 additions & 0 deletions tests/validator/test_validate_start_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
import pytest

from bamboo_engine.builder import (
ConditionalParallelGateway,
ConvergeGateway,
EmptyEndEvent,
EmptyStartEvent,
ExclusiveGateway,
ParallelGateway,
ServiceActivity,
build_tree,
)
from bamboo_engine.exceptions import StartPositionInvalidException
from bamboo_engine.validator import (
get_allowed_start_node_ids,
validate_pipeline_start_node,
)
from bamboo_engine.validator.gateway import validate_gateways


def test_get_allowed_start_node_ids_by_parallel_gateway():
"""
并行网关内的节点将会被忽略
"""
start = EmptyStartEvent()
act_1 = ServiceActivity(component_code="pipe_example_component", name="act_1")
pg = ParallelGateway()
act_2 = ServiceActivity(component_code="pipe_example_component", name="act_2")
act_3 = ServiceActivity(component_code="pipe_example_component", name="act_3")
cg = ConvergeGateway()
end = EmptyEndEvent()
start.extend(act_1).extend(pg).connect(act_2, act_3).to(pg).converge(cg).extend(end)
pipeline = build_tree(start)
# 需要使用 validate_gateways 匹配网关对应的汇聚节点
validate_gateways(pipeline)
allowed_start_node_ids = get_allowed_start_node_ids(pipeline)

assert len(allowed_start_node_ids) == 2
assert allowed_start_node_ids == [start.id, act_1.id]


def test_get_allowed_start_node_ids_by_exclusive_gateway():
start = EmptyStartEvent()
act_1 = ServiceActivity(component_code="pipe_example_component", name="act_1")
eg = ExclusiveGateway(conditions={0: "${act_1_output} < 0", 1: "${act_1_output} >= 0"}, name="act_2 or act_3")
act_2 = ServiceActivity(component_code="pipe_example_component", name="act_2")
act_3 = ServiceActivity(component_code="pipe_example_component", name="act_3")
end = EmptyEndEvent()

start.extend(act_1).extend(eg).connect(act_2, act_3).to(eg).converge(end)
pipeline = build_tree(start)
validate_gateways(pipeline)
allowed_start_node_ids = get_allowed_start_node_ids(pipeline)

assert len(allowed_start_node_ids) == 4
assert allowed_start_node_ids == [start.id, act_1.id, act_2.id, act_3.id]


def test_get_allowed_start_node_ids_by_condition_parallel_gateway():
start = EmptyStartEvent()
act_1 = ServiceActivity(component_code="pipe_example_component", name="act_1")
cpg = ConditionalParallelGateway(
conditions={0: "${act_1_output} < 0", 1: "${act_1_output} >= 0", 2: "${act_1_output} >= 0"},
name="[act_2] or [act_3 and act_4]",
)
act_2 = ServiceActivity(component_code="pipe_example_component", name="act_2")
act_3 = ServiceActivity(component_code="pipe_example_component", name="act_3")
act_4 = ServiceActivity(component_code="pipe_example_component", name="act_4")
cg = ConvergeGateway()
end = EmptyEndEvent()
start.extend(act_1).extend(cpg).connect(act_2, act_3, act_4).to(cpg).converge(cg).extend(end)

pipeline = build_tree(start)
validate_gateways(pipeline)
allowed_start_node_ids = get_allowed_start_node_ids(pipeline)

assert len(allowed_start_node_ids) == 2
assert allowed_start_node_ids == [start.id, act_1.id]


def test_get_allowed_start_node_ids_by_normal():
start = EmptyStartEvent()
act_1 = ServiceActivity(component_code="pipe_example_component", name="act_1")
act_2 = ServiceActivity(component_code="pipe_example_component", name="act_2")
end = EmptyEndEvent()
start.extend(act_1).extend(act_2).extend(end)

pipeline = build_tree(start)
validate_gateways(pipeline)
allowed_start_node_ids = get_allowed_start_node_ids(pipeline)

assert len(allowed_start_node_ids) == 3
assert allowed_start_node_ids == [start.id, act_1.id, act_2.id]


def test_validate_pipeline_start_node():
start = EmptyStartEvent()
act_1 = ServiceActivity(component_code="pipe_example_component", name="act_1")
eg = ExclusiveGateway(conditions={0: "${act_1_output} < 0", 1: "${act_1_output} >= 0"}, name="act_2 or act_3")
act_2 = ServiceActivity(component_code="pipe_example_component", name="act_2")
act_3 = ServiceActivity(component_code="pipe_example_component", name="act_3")
end = EmptyEndEvent()

start.extend(act_1).extend(eg).connect(act_2, act_3).to(eg).converge(end)
pipeline = build_tree(start)
validate_gateways(pipeline)

with pytest.raises(StartPositionInvalidException):
validate_pipeline_start_node(pipeline, eg.id)

validate_pipeline_start_node(pipeline, act_1.id)