Skip to content

Commit

Permalink
framework states rework (#359)
Browse files Browse the repository at this point in the history
# Description

Replace framework_states dict with a pydantic model FrameworkData.

This makes it clear which data is stored in the field and allows
using pydantic validation to process framework data.

# Checklist

- [x] I have performed a self-review of the changes

# To Consider

- Add tests (if functionality is changed)
- Update API reference / tutorials / guides
- Update CONTRIBUTING.md (if devel workflow is changed)
- Update `.ignore` files, scripts (such as `lint`), distribution
manifest (if files are added/deleted)
- Search for references to changed entities in the codebase
  • Loading branch information
RLKRo authored Jun 20, 2024
1 parent facc5f7 commit f631e08
Show file tree
Hide file tree
Showing 14 changed files with 119 additions and 143 deletions.
2 changes: 1 addition & 1 deletion dff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
from dff.pipeline import Pipeline
from dff.script import Context, Script

Script.model_rebuild()
import dff.__rebuild_pydantic_models__
9 changes: 9 additions & 0 deletions dff/__rebuild_pydantic_models__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# flake8: noqa: F401

from dff.pipeline import Pipeline
from dff.pipeline.types import ExtraHandlerRuntimeInfo
from dff.script import Context, Script

Script.model_rebuild()
Context.model_rebuild()
ExtraHandlerRuntimeInfo.model_rebuild()
3 changes: 0 additions & 3 deletions dff/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ComponentExecutionState,
GlobalExtraHandlerType,
ExtraHandlerType,
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
StartConditionCheckerAggregationFunction,
ExtraHandlerConditionFunction,
Expand All @@ -32,5 +31,3 @@
from .service.extra import BeforeHandler, AfterHandler
from .service.group import ServiceGroup
from .service.service import Service, to_service

ExtraHandlerRuntimeInfo.model_rebuild()
3 changes: 1 addition & 2 deletions dff/pipeline/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dff.script import Context

from .types import (
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
ComponentExecutionState,
StartConditionCheckerAggregationFunction,
Expand Down Expand Up @@ -41,7 +40,7 @@ def service_successful_condition(path: Optional[str] = None) -> StartConditionCh
"""

def check_service_state(ctx: Context, _: Pipeline):
state = ctx.framework_states[PIPELINE_STATE_KEY].get(path, ComponentExecutionState.NOT_RUN)
state = ctx.framework_data.service_states.get(path, ComponentExecutionState.NOT_RUN)
return ComponentExecutionState[state] == ComponentExecutionState.FINISHED

return check_service_state
Expand Down
90 changes: 42 additions & 48 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ def __init__(
self._clean_turn_cache = True

async def __call__(self, pipeline: Pipeline, ctx: Context):
# context init
self._context_init(ctx)
await self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT)

# get previous node
Expand All @@ -121,7 +119,7 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
self._get_next_node(ctx)
await self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE)

ctx.add_label(ctx.framework_states["actor"]["next_label"][:2])
ctx.add_label(ctx.framework_data.actor_data["next_label"][:2])

# rewrite next node
self._rewrite_next_node(ctx)
Expand All @@ -132,89 +130,85 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING)

# create response
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline
ctx.framework_data.actor_data["response"] = await self.run_response(
ctx.framework_data.actor_data["pre_response_processed_node"].response, ctx, pipeline
)
await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE)
ctx.add_response(ctx.framework_states["actor"]["response"])
ctx.add_response(ctx.framework_data.actor_data["response"])

await self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN)
if self._clean_turn_cache:
cache_clear()

del ctx.framework_states["actor"]

@staticmethod
def _context_init(ctx: Optional[Union[Context, dict, str]] = None):
ctx.framework_states["actor"] = {}
ctx.framework_data.actor_data.clear()

def _get_previous_node(self, ctx: Context):
ctx.framework_states["actor"]["previous_label"] = (
ctx.framework_data.actor_data["previous_label"] = (
normalize_label(ctx.last_label) if ctx.last_label else self.start_label
)
ctx.framework_states["actor"]["previous_node"] = self.script.get(
ctx.framework_states["actor"]["previous_label"][0], {}
).get(ctx.framework_states["actor"]["previous_label"][1], Node())
ctx.framework_data.actor_data["previous_node"] = self.script.get(
ctx.framework_data.actor_data["previous_label"][0], {}
).get(ctx.framework_data.actor_data["previous_label"][1], Node())

async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
# GLOBAL
ctx.framework_states["actor"]["global_transitions"] = (
ctx.framework_data.actor_data["global_transitions"] = (
self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions
)
ctx.framework_states["actor"]["global_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global"
ctx.framework_data.actor_data["global_true_label"] = await self._get_true_label(
ctx.framework_data.actor_data["global_transitions"], ctx, pipeline, GLOBAL, "global"
)

# LOCAL
ctx.framework_states["actor"]["local_transitions"] = (
self.script.get(ctx.framework_states["actor"]["previous_label"][0], {}).get(LOCAL, Node()).transitions
ctx.framework_data.actor_data["local_transitions"] = (
self.script.get(ctx.framework_data.actor_data["previous_label"][0], {}).get(LOCAL, Node()).transitions
)
ctx.framework_states["actor"]["local_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["local_transitions"],
ctx.framework_data.actor_data["local_true_label"] = await self._get_true_label(
ctx.framework_data.actor_data["local_transitions"],
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
ctx.framework_data.actor_data["previous_label"][0],
"local",
)

# NODE
ctx.framework_states["actor"]["node_transitions"] = ctx.framework_states["actor"][
ctx.framework_data.actor_data["node_transitions"] = ctx.framework_data.actor_data[
"pre_transitions_processed_node"
].transitions
ctx.framework_states["actor"]["node_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["node_transitions"],
ctx.framework_data.actor_data["node_true_label"] = await self._get_true_label(
ctx.framework_data.actor_data["node_transitions"],
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
ctx.framework_data.actor_data["previous_label"][0],
"node",
)

def _get_next_node(self, ctx: Context):
# choose next label
ctx.framework_states["actor"]["next_label"] = self._choose_label(
ctx.framework_states["actor"]["node_true_label"], ctx.framework_states["actor"]["local_true_label"]
ctx.framework_data.actor_data["next_label"] = self._choose_label(
ctx.framework_data.actor_data["node_true_label"], ctx.framework_data.actor_data["local_true_label"]
)
ctx.framework_states["actor"]["next_label"] = self._choose_label(
ctx.framework_states["actor"]["next_label"], ctx.framework_states["actor"]["global_true_label"]
ctx.framework_data.actor_data["next_label"] = self._choose_label(
ctx.framework_data.actor_data["next_label"], ctx.framework_data.actor_data["global_true_label"]
)
# get next node
ctx.framework_states["actor"]["next_node"] = self.script.get(
ctx.framework_states["actor"]["next_label"][0], {}
).get(ctx.framework_states["actor"]["next_label"][1])
ctx.framework_data.actor_data["next_node"] = self.script.get(
ctx.framework_data.actor_data["next_label"][0], {}
).get(ctx.framework_data.actor_data["next_label"][1])

def _rewrite_previous_node(self, ctx: Context):
node = ctx.framework_states["actor"]["previous_node"]
flow_label = ctx.framework_states["actor"]["previous_label"][0]
ctx.framework_states["actor"]["previous_node"] = self._overwrite_node(
node = ctx.framework_data.actor_data["previous_node"]
flow_label = ctx.framework_data.actor_data["previous_label"][0]
ctx.framework_data.actor_data["previous_node"] = self._overwrite_node(
node,
flow_label,
only_current_node_transitions=True,
)

def _rewrite_next_node(self, ctx: Context):
node = ctx.framework_states["actor"]["next_node"]
flow_label = ctx.framework_states["actor"]["next_label"][0]
ctx.framework_states["actor"]["next_node"] = self._overwrite_node(node, flow_label)
node = ctx.framework_data.actor_data["next_node"]
flow_label = ctx.framework_data.actor_data["next_label"][0]
ctx.framework_data.actor_data["next_node"] = self._overwrite_node(node, flow_label)

def _overwrite_node(
self,
Expand Down Expand Up @@ -290,18 +284,18 @@ async def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline
The execution order depends on the value of the :py:class:`.Pipeline`'s
`parallelize_processing` flag.
"""
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"])
pre_transitions_processing = ctx.framework_states["actor"]["previous_node"].pre_transitions_processing
ctx.framework_data.actor_data["processed_node"] = copy.deepcopy(ctx.framework_data.actor_data["previous_node"])
pre_transitions_processing = ctx.framework_data.actor_data["previous_node"].pre_transitions_processing

if pipeline.parallelize_processing:
await self._run_processing_parallel(pre_transitions_processing, ctx, pipeline)
else:
await self._run_processing_sequential(pre_transitions_processing, ctx, pipeline)

ctx.framework_states["actor"]["pre_transitions_processed_node"] = ctx.framework_states["actor"][
ctx.framework_data.actor_data["pre_transitions_processed_node"] = ctx.framework_data.actor_data[
"processed_node"
]
del ctx.framework_states["actor"]["processed_node"]
del ctx.framework_data.actor_data["processed_node"]

async def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline) -> None:
"""
Expand All @@ -312,16 +306,16 @@ async def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline) -
The execution order depends on the value of the :py:class:`.Pipeline`'s
`parallelize_processing` flag.
"""
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["next_node"])
pre_response_processing = ctx.framework_states["actor"]["next_node"].pre_response_processing
ctx.framework_data.actor_data["processed_node"] = copy.deepcopy(ctx.framework_data.actor_data["next_node"])
pre_response_processing = ctx.framework_data.actor_data["next_node"].pre_response_processing

if pipeline.parallelize_processing:
await self._run_processing_parallel(pre_response_processing, ctx, pipeline)
else:
await self._run_processing_sequential(pre_response_processing, ctx, pipeline)

ctx.framework_states["actor"]["pre_response_processed_node"] = ctx.framework_states["actor"]["processed_node"]
del ctx.framework_states["actor"]["processed_node"]
ctx.framework_data.actor_data["pre_response_processed_node"] = ctx.framework_data.actor_data["processed_node"]
del ctx.framework_data.actor_data["processed_node"]

async def _get_true_label(
self,
Expand Down
16 changes: 5 additions & 11 deletions dff/pipeline/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
import logging
import abc
import asyncio
import copy
from typing import Optional, Awaitable, TYPE_CHECKING

from dff.script import Context

from ..service.extra import BeforeHandler, AfterHandler
from ..conditions import always_start_condition
from ..types import (
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
ComponentExecutionState,
ServiceRuntimeInfo,
Expand Down Expand Up @@ -109,28 +107,24 @@ def __init__(

def _set_state(self, ctx: Context, value: ComponentExecutionState):
"""
Method for component runtime state setting, state is preserved in `ctx.framework_states` dict,
in subdict, dedicated to this library.
Method for component runtime state setting, state is preserved in `ctx.framework_data`.
:param ctx: :py:class:`~.Context` to keep state in.
:param value: State to set.
:return: `None`
"""
if PIPELINE_STATE_KEY not in ctx.framework_states:
ctx.framework_states[PIPELINE_STATE_KEY] = {}
ctx.framework_states[PIPELINE_STATE_KEY][self.path] = value
ctx.framework_data.service_states[self.path] = value

def get_state(self, ctx: Context, default: Optional[ComponentExecutionState] = None) -> ComponentExecutionState:
"""
Method for component runtime state getting, state is preserved in `ctx.framework_states` dict,
in subdict, dedicated to this library.
Method for component runtime state getting, state is preserved in `ctx.framework_data`.
:param ctx: :py:class:`~.Context` to get state from.
:param default: Default to return if no record found
(usually it's :py:attr:`~.pipeline.types.ComponentExecutionState.NOT_RUN`).
:return: :py:class:`~pipeline.types.ComponentExecutionState` of this service or default if not found.
"""
return ctx.framework_states[PIPELINE_STATE_KEY].get(self.path, default if default is not None else None)
return ctx.framework_data.service_states.get(self.path, default if default is not None else None)

@property
def asynchronous(self) -> bool:
Expand Down Expand Up @@ -218,7 +212,7 @@ def _get_runtime_info(self, ctx: Context) -> ServiceRuntimeInfo:
path=self.path if self.path is not None else "[None]",
timeout=self.timeout,
asynchronous=self.asynchronous,
execution_state=copy.deepcopy(ctx.framework_states[PIPELINE_STATE_KEY]),
execution_state=ctx.framework_data.service_states.copy(),
)

@property
Expand Down
4 changes: 1 addition & 3 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
ExtraHandlerFunction,
ExtraHandlerBuilder,
)
from ..types import PIPELINE_STATE_KEY
from .utils import finalize_service_group, pretty_format_component_info_dict
from dff.pipeline.pipeline.actor import Actor

Expand Down Expand Up @@ -320,14 +319,13 @@ async def _run_pipeline(
if update_ctx_misc is not None:
ctx.misc.update(update_ctx_misc)

ctx.framework_states[PIPELINE_STATE_KEY] = {}
ctx.add_request(request)
result = await self._services_pipeline(ctx, self)

if asyncio.iscoroutine(result):
await result

del ctx.framework_states[PIPELINE_STATE_KEY]
ctx.framework_data.service_states.clear()

if isinstance(self.context_storage, DBContextStorage):
await self.context_storage.set_item_async(ctx_id, ctx)
Expand Down
Loading

0 comments on commit f631e08

Please sign in to comment.