Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Feb 16, 2024
1 parent 6d982e0 commit 311ec86
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 11 deletions.
2 changes: 2 additions & 0 deletions dff/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ComponentExecutionState,
GlobalExtraHandlerType,
ExtraHandlerType,
PIPELINE_EXCEPTION_KEY,
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
StartConditionCheckerAggregationFunction,
Expand All @@ -27,6 +28,7 @@
PipelineBuilder,
)

from .pipeline.actor import LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY
from .pipeline.pipeline import Pipeline, ACTOR

from .service.extra import BeforeHandler, AfterHandler
Expand Down
66 changes: 58 additions & 8 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from __future__ import annotations
import logging
import asyncio
from typing import Union, Callable, Optional, Dict, List, TYPE_CHECKING
from typing import Type, Union, Callable, Optional, Dict, List, TYPE_CHECKING
import copy

from dff.utils.turn_caching import cache_clear
Expand All @@ -37,6 +37,10 @@
from dff.script.core.normalization import normalize_label, normalize_response
from dff.script.core.keywords import GLOBAL, LOCAL
from dff.pipeline.service.utils import wrap_sync_function_in_async
from dff.pipeline.types import PIPELINE_EXCEPTION_KEY

LATEST_EXCEPTION_KEY = "LATEST_EXCEPTION"
LATEST_FAILED_NODE_KEY = "LATEST_FAILED_NODE"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -131,7 +135,7 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING)

# get true labels for scopes (GLOBAL, LOCAL, NODE)
await self._get_true_labels(ctx, pipeline)
await self._get_true_labels(ctx, pipeline, False)
await self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS)

# get next node
Expand Down Expand Up @@ -161,6 +165,43 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):

del ctx.framework_states["actor"]

async def process_exception(self, pipeline: Pipeline, ctx: Context):
# context init
self._context_init(ctx)

# get previous node
self._get_previous_node(ctx)

# rewrite previous node
self._rewrite_previous_node(ctx)

# run pre transitions processing
await self._run_pre_transitions_processing(ctx, pipeline)

# get true labels for scopes (GLOBAL, LOCAL, NODE)
await self._get_true_labels(ctx, pipeline, True)

# get next node
self._get_next_node(ctx)

ctx.add_label(ctx.framework_states["actor"]["next_label"][:2])

# rewrite next node
self._rewrite_next_node(ctx)

# run pre response processing
await self._run_pre_response_processing(ctx, pipeline)

# create response
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline
)
ctx.add_response(ctx.framework_states["actor"]["response"])

if self._clean_turn_cache:
cache_clear()
del ctx.framework_states["actor"]

@staticmethod
def _context_init(ctx: Optional[Union[Context, dict, str]] = None):
ctx.framework_states["actor"] = {}
Expand All @@ -173,13 +214,13 @@ def _get_previous_node(self, ctx: Context):
ctx.framework_states["actor"]["previous_label"][0], {}
).get(ctx.framework_states["actor"]["previous_label"][1], Node())

async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
async def _get_true_labels(self, ctx: Context, pipeline: Pipeline, is_exceptional: bool):
# GLOBAL
ctx.framework_states["actor"]["global_transitions"] = (
self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions
)
ctx.framework_states["actor"]["global_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global"
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, is_exceptional, "global"
)

# LOCAL
Expand All @@ -191,6 +232,7 @@ async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
is_exceptional,
"local",
)

Expand All @@ -203,6 +245,7 @@ async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
is_exceptional,
"node",
)

Expand Down Expand Up @@ -346,14 +389,21 @@ async def _get_true_label(
ctx: Context,
pipeline: Pipeline,
flow_label: LabelType,
is_exceptional: bool,
transition_info: str = "",
) -> Optional[NodeLabel3Type]:
true_labels = []

cond_booleans = await asyncio.gather(
*(self.condition_handler(condition, ctx, pipeline) for condition in transitions.values())
)
for label, cond_is_true in zip(transitions.keys(), cond_booleans):
cond_values = await asyncio.gather(*(self.condition_handler(condition, ctx, pipeline) for condition in transitions.values()))
cond_items = list(zip(transitions.keys(), cond_values))

if is_exceptional:
exception = ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY]
cond_items = [(label, isinstance(exception, type(value))) for label, value in cond_items if issubclass(type(value), BaseException)]
else:
cond_items = [(label, value) for label, value in cond_items if isinstance(value, bool)]

for label, cond_is_true in cond_items:
if cond_is_true:
if callable(label):
label = await wrap_sync_function_in_async(label, ctx, pipeline)
Expand Down
11 changes: 8 additions & 3 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
from typing import Union, List, Dict, Optional, Hashable, Callable

from dff.context_storages import DBContextStorage
from dff.script import Script, Context, ActorStage
from dff.script import NodeLabel2Type, Message
from dff.script import Script, Context, ActorStage, NodeLabel2Type, Message
from dff.utils.turn_caching import cache_clear

from dff.messengers.common import MessengerInterface, CLIMessengerInterface
from ..service.group import ServiceGroup
from ..types import (
ComponentExecutionState,
ServiceBuilder,
ServiceGroupBuilder,
PipelineBuilder,
GlobalExtraHandlerType,
ExtraHandlerFunction,
ExtraHandlerBuilder,
)
from ..types import PIPELINE_STATE_KEY
from ..types import PIPELINE_EXCEPTION_KEY, PIPELINE_STATE_KEY
from .utils import finalize_service_group, pretty_format_component_info_dict
from dff.pipeline.pipeline.actor import Actor

Expand Down Expand Up @@ -343,13 +343,18 @@ async def _run_pipeline(
ctx.misc.update(update_ctx_misc)

ctx.framework_states[PIPELINE_STATE_KEY] = {}
ctx.framework_states[PIPELINE_EXCEPTION_KEY] = {}
ctx.add_request(request)
result = await self._services_pipeline(ctx, self)

if asyncio.iscoroutine(result):
await result

if self._services_pipeline.get_state(ctx) == ComponentExecutionState.FAILED:
await self.actor.process_exception(self, ctx)

del ctx.framework_states[PIPELINE_STATE_KEY]
del ctx.framework_states[PIPELINE_EXCEPTION_KEY]

if isinstance(self.context_storage, DBContextStorage):
await self.context_storage.set_item_async(ctx_id, ctx)
Expand Down
11 changes: 11 additions & 0 deletions dff/pipeline/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

from .utils import wrap_sync_function_in_async, collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates
from ..types import (
PIPELINE_EXCEPTION_KEY,
ServiceBuilder,
StartConditionCheckerFunction,
ComponentExecutionState,
ExtraHandlerBuilder,
ExtraHandlerType,
)
from ..pipeline.actor import LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY
from ..pipeline.component import PipelineComponent

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -133,6 +135,13 @@ async def _run_as_actor(self, ctx: Context, pipeline: Pipeline) -> None:
await pipeline.actor(pipeline, ctx)
self._set_state(ctx, ComponentExecutionState.FINISHED)
except Exception as exc:
if "actor" in ctx.framework_states:
last_label = ctx.framework_states["actor"]["next_label"]
latest_node = f"{self.name}:{last_label[0]}:{last_label[1]}"
else:
latest_node = self.name
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY] = exc
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_FAILED_NODE_KEY] = latest_node
self._set_state(ctx, ComponentExecutionState.FAILED)
logger.error(f"Actor '{self.name}' execution failed!\n{exc}")

Expand All @@ -152,6 +161,8 @@ async def _run_as_service(self, ctx: Context, pipeline: Pipeline) -> None:
else:
self._set_state(ctx, ComponentExecutionState.NOT_RUN)
except Exception as e:
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_EXCEPTION_KEY] = e
ctx.framework_states[PIPELINE_EXCEPTION_KEY][LATEST_FAILED_NODE_KEY] = self.name
self._set_state(ctx, ComponentExecutionState.FAILED)
logger.error(f"Service '{self.name}' execution failed!\n{e}")

Expand Down
3 changes: 3 additions & 0 deletions dff/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ class ExtraHandlerType(str, Enum):
AFTER = "AFTER"


PIPELINE_EXCEPTION_KEY = "EXCEPTION"


PIPELINE_STATE_KEY = "PIPELINE"
"""
PIPELINE: storage for services and groups execution status.
Expand Down
136 changes: 136 additions & 0 deletions tutorials/script/core/10_error_conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# %% [markdown]
"""
# Core: 10. Error conditions
"""

# %pip install dff

# %%
from typing import Type
from dff.script import GLOBAL, TRANSITIONS, RESPONSE, Context, Message
from dff.pipeline import PIPELINE_EXCEPTION_KEY, LATEST_EXCEPTION_KEY, LATEST_FAILED_NODE_KEY, Pipeline
import dff.script.conditions as cnd
import dff.script.labels as lbl

from dff.utils.testing.common import (
check_happy_path,
is_interactive_mode,
run_interactive_mode,
)


def raise_exception(exception_class: Type[BaseException]) -> Message:
raise exception_class("Some evil cause!")


def print_exception(name: str, _: Pipeline, ctx: Context) -> Message:
exception = ctx.framework_states[PIPELINE_EXCEPTION_KEY].get(LATEST_EXCEPTION_KEY, None)
message = "UNKNOWN" if exception is None else str(exception)
source = ctx.framework_states[PIPELINE_EXCEPTION_KEY].get(LATEST_FAILED_NODE_KEY, None)
return Message(f"Exception type {name} with message '{message}' received from node {source}!")


# %%
toy_script = {
GLOBAL: {
TRANSITIONS: {
("error_flow", "node_name_handler", 1.1): NameError,
("error_flow", "node_buffer_handler", 1.1): BufferError,
},
},
"error_flow": {
"start_node": {
RESPONSE: Message(),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("start")),
},
},
"node_start_exceptor": {
RESPONSE: Message("Select an exception to throw!"),
TRANSITIONS: {
"node_name_thrower": cnd.exact_match(Message("name")),
"node_buffer_thrower": cnd.exact_match(Message("buffer")),
"node_file_thrower": cnd.exact_match(Message("fallback")),
},
},
"node_name_thrower": {
RESPONSE: lambda _, __: raise_exception(NameError),
},
"node_buffer_thrower": {
RESPONSE: lambda _, __: raise_exception(BufferError),
},
"node_file_thrower": {
RESPONSE: lambda _, __: raise_exception(FileNotFoundError),
},
"node_name_handler": {
RESPONSE: lambda ctx, pipeline: print_exception("Name Error", pipeline, ctx),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("okay...")),
},
},
"node_buffer_handler": {
RESPONSE: lambda ctx, pipeline: print_exception("Buffer Error", pipeline, ctx),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("okay...")),
},
},
"fallback_node": {
RESPONSE: Message(f"Unexpected message received or an unknown exception caught!"),
TRANSITIONS: {
"node_start_exceptor": cnd.exact_match(Message("okay...")),
},
},
}
}


happy_path = (
(
Message("start"),
Message("Select an exception to throw!"),
),
(
Message("name"),
Message("Exception type Name Error with message 'Some evil cause!' received from node actor_0:error_flow:node_name_thrower!"),
),
(
Message("okay..."),
Message("Select an exception to throw!"),
),
(
Message("buffer"),
Message("Exception type Buffer Error with message 'Some evil cause!' received from node actor_0:error_flow:node_buffer_thrower!"),
),
(
Message("okay..."),
Message("Select an exception to throw!"),
),
(
Message("fallback"),
Message("Unexpected message received or an unknown exception caught!"),
),
(
Message("okay..."),
Message("Select an exception to throw!"),
),
(
Message("something"),
Message("Unexpected message received or an unknown exception caught!"),
),
)


# %%
pipeline = Pipeline.from_script(
toy_script,
start_label=("error_flow", "start_node"),
fallback_label=("error_flow", "fallback_node"),
validation_stage=False,
)

if __name__ == "__main__":
check_happy_path(pipeline, happy_path)
if is_interactive_mode():
run_interactive_mode(pipeline)

# %%

0 comments on commit 311ec86

Please sign in to comment.