diff --git a/dff/messengers/common/__init__.py b/dff/messengers/common/__init__.py index ceac90c63..d9c66d921 100644 --- a/dff/messengers/common/__init__.py +++ b/dff/messengers/common/__init__.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- from .interface import MessengerInterface, PollingMessengerInterface, CallbackMessengerInterface, CLIMessengerInterface -from .types import PipelineRunnerFunction, PollingInterfaceLoopFunction +from .types import PollingInterfaceLoopFunction diff --git a/dff/messengers/common/interface.py b/dff/messengers/common/interface.py index 96ffc288b..13747b06c 100644 --- a/dff/messengers/common/interface.py +++ b/dff/messengers/common/interface.py @@ -4,15 +4,18 @@ The Message Interfaces module contains several basic classes that define the message interfaces. These classes provide a way to define the structure of the messengers that are used to communicate with the DFF. """ +from __future__ import annotations import abc import asyncio import logging import uuid -from typing import Optional, Any, List, Tuple, TextIO, Hashable +from typing import Optional, Any, List, Tuple, TextIO, Hashable, TYPE_CHECKING from dff.script import Context, Message +from dff.messengers.common.types import PollingInterfaceLoopFunction -from .types import PipelineRunnerFunction, PollingInterfaceLoopFunction +if TYPE_CHECKING: + from dff.pipeline.types import PipelineRunnerFunction logger = logging.getLogger(__name__) @@ -29,9 +32,8 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction): Method invoked when message interface is instantiated and connection is established. May be used for sending an introduction message or displaying general bot information. - :param pipeline_runner: A function that should return pipeline response to user request; + :param pipeline_runner: A function that should process user request and return context; usually it's a :py:meth:`~dff.pipeline.pipeline.pipeline.Pipeline._run_pipeline` function. - :type pipeline_runner: PipelineRunnerFunction """ raise NotImplementedError @@ -96,12 +98,10 @@ async def connect( The looping behavior is determined by `loop` and `timeout`, for most cases the loop itself shouldn't be overridden. - :param pipeline_runner: A function that should return pipeline response to user request; + :param pipeline_runner: A function that should process user request and return context; usually it's a :py:meth:`~dff.pipeline.pipeline.pipeline.Pipeline._run_pipeline` function. - :type pipeline_runner: PipelineRunnerFunction :param loop: a function that determines whether polling should be continued; called in each cycle, should return `True` to continue polling or `False` to stop. - :type loop: PollingInterfaceLoopFunction :param timeout: a time interval between polls (in seconds). """ while loop(): @@ -124,33 +124,23 @@ def __init__(self): async def connect(self, pipeline_runner: PipelineRunnerFunction): self._pipeline_runner = pipeline_runner - async def on_request_async(self, request: Any, ctx_id: Hashable) -> Context: + async def on_request_async( + self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None + ) -> Context: """ - Method invoked on user input. This method works just like - :py:meth:`~dff.pipeline.pipeline.pipeline.Pipeline._run_pipeline`, - however callback message interface may contain additional functionality (e.g. for external API accessing). - Return context that represents dialog with the user; - `last_response`, `id` and some dialog info can be extracted from there. 
- - :param request: User input. - :param ctx_id: Any unique id that will be associated with dialog between this user and pipeline. - :return: Context that represents dialog with the user. + Method that should be invoked on user input. + This method has the same signature as :py:class:`~dff.pipeline.types.PipelineRunnerFunction`. """ - return await self._pipeline_runner(request, ctx_id) + return await self._pipeline_runner(request, ctx_id, update_ctx_misc) - def on_request(self, request: Any, ctx_id: Hashable) -> Context: + def on_request( + self, request: Any, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None + ) -> Context: """ - Method invoked on user input. This method works just like - :py:meth:`~dff.pipeline.pipeline.pipeline.Pipeline._run_pipeline`, - however callback message interface may contain additional functionality (e.g. for external API accessing). - Return context that represents dialog with the user; - `last_response`, `id` and some dialog info can be extracted from there. - - :param request: User input. - :param ctx_id: Any unique id that will be associated with dialog between this user and pipeline. - :return: Context that represents dialog with the user. + Method that should be invoked on user input. + This method has the same signature as :py:class:`~dff.pipeline.types.PipelineRunnerFunction`. """ - return asyncio.run(self.on_request_async(request, ctx_id)) + return asyncio.run(self.on_request_async(request, ctx_id, update_ctx_misc)) class CLIMessengerInterface(PollingMessengerInterface): @@ -183,9 +173,8 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs): """ The CLIProvider generates new dialog id used to user identification on each `connect` call. - :param pipeline_runner: A function that should return pipeline response to user request; + :param pipeline_runner: A function that should process user request and return context; usually it's a :py:meth:`~dff.pipeline.pipeline.pipeline.Pipeline._run_pipeline` function. - :type pipeline_runner: PipelineRunnerFunction :param \\**kwargs: argument, added for compatibility with super class, it shouldn't be used normally. """ self._ctx_id = uuid.uuid4() diff --git a/dff/messengers/common/types.py b/dff/messengers/common/types.py index 7c9275a93..e43769551 100644 --- a/dff/messengers/common/types.py +++ b/dff/messengers/common/types.py @@ -1,22 +1,11 @@ """ Types ----- -The Types module contains two special types that are used throughout the `DFF Messengers`. -The first type is used for the messenger interface to client interaction and the second one -to control the polling loop. +The Types module contains special types that are used throughout the `DFF Messengers`. """ -from typing import Callable, Any, Hashable, Awaitable +from typing import Callable from typing_extensions import TypeAlias -from dff.script import Context - - -PipelineRunnerFunction: TypeAlias = Callable[[Any, Hashable], Awaitable[Context]] -""" -A function type for messenger_interface-to-client interaction. -Accepts anything (user input) and hashable value (current context id), returns string (answer from pipeline). 
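The alias removed here is superseded by the `PipelineRunnerFunction` protocol that this patch defines in `dff.pipeline.types`; the messenger side now imports it for annotations only. A minimal sketch of the `TYPE_CHECKING` pattern the new imports rely on (the `connect` stub is illustrative): combined with `from __future__ import annotations`, the name is resolved lazily, so the runtime import cycle between `dff.messengers` and `dff.pipeline` is avoided.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # visible to type checkers only, never imported at runtime
    from dff.pipeline.types import PipelineRunnerFunction


async def connect(pipeline_runner: PipelineRunnerFunction):
    ...
```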
-""" - PollingInterfaceLoopFunction: TypeAlias = Callable[[], bool] """ diff --git a/dff/messengers/telegram/interface.py b/dff/messengers/telegram/interface.py index 5d8f1a902..ba482f01b 100644 --- a/dff/messengers/telegram/interface.py +++ b/dff/messengers/telegram/interface.py @@ -9,7 +9,8 @@ from telebot import types, apihelper -from dff.messengers.common import MessengerInterface, PipelineRunnerFunction, CallbackMessengerInterface +from dff.messengers.common import MessengerInterface, CallbackMessengerInterface +from dff.pipeline.types import PipelineRunnerFunction from .messenger import TelegramMessenger from .message import TelegramMessage diff --git a/dff/pipeline/pipeline/actor.py b/dff/pipeline/pipeline/actor.py index 1bd5878b0..c567c554a 100644 --- a/dff/pipeline/pipeline/actor.py +++ b/dff/pipeline/pipeline/actor.py @@ -23,7 +23,8 @@ .. figure:: /_static/drawio/dfe/user_actor.png """ import logging -from typing import Union, Callable, Optional, Dict, List, Any, ForwardRef +import asyncio +from typing import Union, Callable, Optional, Dict, List, ForwardRef import copy from dff.utils.turn_caching import cache_clear @@ -34,6 +35,7 @@ from dff.script.core.script import Script, Node from dff.script.core.normalization import normalize_label, normalize_response from dff.script.core.keywords import GLOBAL, LOCAL +from dff.pipeline.service.utils import wrap_sync_function_in_async logger = logging.getLogger(__name__) @@ -109,80 +111,72 @@ def __init__( # NB! The following API is highly experimental and may be removed at ANY time WITHOUT FURTHER NOTICE!! self._clean_turn_cache = True - def __call__( - self, pipeline: Pipeline, ctx: Optional[Union[Context, dict, str]] = None, *args, **kwargs - ) -> Union[Context, dict, str]: + async def __call__(self, pipeline: Pipeline, ctx: Context): # context init - ctx = self._context_init(ctx, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT, *args, **kwargs) + self._context_init(ctx) + await self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT) # get previous node - ctx = self._get_previous_node(ctx, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.GET_PREVIOUS_NODE, *args, **kwargs) + self._get_previous_node(ctx) + await self._run_handlers(ctx, pipeline, ActorStage.GET_PREVIOUS_NODE) # rewrite previous node - ctx = self._rewrite_previous_node(ctx, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.REWRITE_PREVIOUS_NODE, *args, **kwargs) + self._rewrite_previous_node(ctx) + await self._run_handlers(ctx, pipeline, ActorStage.REWRITE_PREVIOUS_NODE) # run pre transitions processing - ctx = self._run_pre_transitions_processing(ctx, pipeline, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING, *args, **kwargs) + await self._run_pre_transitions_processing(ctx, pipeline) + await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_TRANSITIONS_PROCESSING) # get true labels for scopes (GLOBAL, LOCAL, NODE) - ctx = self._get_true_labels(ctx, pipeline, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS, *args, **kwargs) + await self._get_true_labels(ctx, pipeline) + await self._run_handlers(ctx, pipeline, ActorStage.GET_TRUE_LABELS) # get next node - ctx = self._get_next_node(ctx, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE, *args, **kwargs) + self._get_next_node(ctx) + await self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE) 
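Since `_run_handlers` now routes every stage handler through `wrap_sync_function_in_async` and awaits the results with `asyncio.gather`, handlers registered for an `ActorStage` may be plain functions or coroutines. A hedged sketch (the handler names and bodies are made up):

```python
import logging

from dff.pipeline import Pipeline
from dff.script import ActorStage, Context

logger = logging.getLogger(__name__)


def log_next_label(ctx: Context, _: Pipeline):  # synchronous handlers still work
    logger.debug(ctx.framework_states["actor"]["next_label"])


async def notify_external_service(ctx: Context, _: Pipeline):  # coroutines now work too
    ...


# passed as the `handlers` argument of Pipeline.from_script
handlers = {ActorStage.GET_NEXT_NODE: [log_next_label, notify_external_service]}
```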
ctx.add_label(ctx.framework_states["actor"]["next_label"][:2]) # rewrite next node - ctx = self._rewrite_next_node(ctx, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.REWRITE_NEXT_NODE, *args, **kwargs) + self._rewrite_next_node(ctx) + await self._run_handlers(ctx, pipeline, ActorStage.REWRITE_NEXT_NODE) # run pre response processing - ctx = self._run_pre_response_processing(ctx, pipeline, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING, *args, **kwargs) + await self._run_pre_response_processing(ctx, pipeline) + await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING) # create response - ctx.framework_states["actor"]["response"] = ctx.framework_states["actor"][ - "pre_response_processed_node" - ].run_response(ctx, pipeline, *args, **kwargs) - self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE, *args, **kwargs) + ctx.framework_states["actor"]["response"] = await self.run_response( + ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline + ) + await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE) ctx.add_response(ctx.framework_states["actor"]["response"]) - self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN, *args, **kwargs) + await self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN) if self._clean_turn_cache: cache_clear() del ctx.framework_states["actor"] - return ctx @staticmethod - def _context_init(ctx: Optional[Union[Context, dict, str]] = None, *args, **kwargs) -> Context: - ctx = Context.cast(ctx) - if not ctx.requests: - ctx.add_request(Message()) + def _context_init(ctx: Optional[Union[Context, dict, str]] = None): ctx.framework_states["actor"] = {} - return ctx - def _get_previous_node(self, ctx: Context, *args, **kwargs) -> Context: + def _get_previous_node(self, ctx: Context): ctx.framework_states["actor"]["previous_label"] = ( normalize_label(ctx.last_label) if ctx.last_label else self.start_label ) ctx.framework_states["actor"]["previous_node"] = self.script.get( ctx.framework_states["actor"]["previous_label"][0], {} ).get(ctx.framework_states["actor"]["previous_label"][1], Node()) - return ctx - def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: + async def _get_true_labels(self, ctx: Context, pipeline: Pipeline): # GLOBAL ctx.framework_states["actor"]["global_transitions"] = ( self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions ) - ctx.framework_states["actor"]["global_true_label"] = self._get_true_label( + ctx.framework_states["actor"]["global_true_label"] = await self._get_true_label( ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global" ) @@ -190,7 +184,7 @@ def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> ctx.framework_states["actor"]["local_transitions"] = ( self.script.get(ctx.framework_states["actor"]["previous_label"][0], {}).get(LOCAL, Node()).transitions ) - ctx.framework_states["actor"]["local_true_label"] = self._get_true_label( + ctx.framework_states["actor"]["local_true_label"] = await self._get_true_label( ctx.framework_states["actor"]["local_transitions"], ctx, pipeline, @@ -202,16 +196,15 @@ def _get_true_labels(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> ctx.framework_states["actor"]["node_transitions"] = ctx.framework_states["actor"][ "pre_transitions_processed_node" ].transitions - ctx.framework_states["actor"]["node_true_label"] = self._get_true_label( + 
ctx.framework_states["actor"]["node_true_label"] = await self._get_true_label( ctx.framework_states["actor"]["node_transitions"], ctx, pipeline, ctx.framework_states["actor"]["previous_label"][0], "node", ) - return ctx - def _get_next_node(self, ctx: Context, *args, **kwargs) -> Context: + def _get_next_node(self, ctx: Context): # choose next label ctx.framework_states["actor"]["next_label"] = self._choose_label( ctx.framework_states["actor"]["node_true_label"], ctx.framework_states["actor"]["local_true_label"] @@ -223,9 +216,8 @@ def _get_next_node(self, ctx: Context, *args, **kwargs) -> Context: ctx.framework_states["actor"]["next_node"] = self.script.get( ctx.framework_states["actor"]["next_label"][0], {} ).get(ctx.framework_states["actor"]["next_label"][1]) - return ctx - def _rewrite_previous_node(self, ctx: Context, *args, **kwargs) -> Context: + def _rewrite_previous_node(self, ctx: Context): node = ctx.framework_states["actor"]["previous_node"] flow_label = ctx.framework_states["actor"]["previous_label"][0] ctx.framework_states["actor"]["previous_node"] = self._overwrite_node( @@ -233,21 +225,17 @@ def _rewrite_previous_node(self, ctx: Context, *args, **kwargs) -> Context: flow_label, only_current_node_transitions=True, ) - return ctx - def _rewrite_next_node(self, ctx: Context, *args, **kwargs) -> Context: + def _rewrite_next_node(self, ctx: Context): node = ctx.framework_states["actor"]["next_node"] flow_label = ctx.framework_states["actor"]["next_label"][0] ctx.framework_states["actor"]["next_node"] = self._overwrite_node(node, flow_label) - return ctx def _overwrite_node( self, current_node: Node, flow_label: LabelType, - *args, only_current_node_transitions: bool = False, - **kwargs, ) -> Node: overwritten_node = copy.deepcopy(self.script.get(GLOBAL, {}).get(GLOBAL, Node())) local_node = self.script.get(flow_label, {}).get(LOCAL, Node()) @@ -262,43 +250,114 @@ def _overwrite_node( overwritten_node.transitions = current_node.transitions return overwritten_node - def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"]) - ctx = ctx.framework_states["actor"]["previous_node"].run_pre_transitions_processing( - ctx, pipeline, *args, **kwargs + async def run_response( + self, + response: Optional[Union[Message, Callable[..., Message]]], + ctx: Context, + pipeline: Pipeline, + ) -> Message: + """ + Executes the normalized response as an asynchronous function. + See the details in the :py:func:`~normalize_response` function of `normalization.py`. + """ + response = normalize_response(response) + return await wrap_sync_function_in_async(response, ctx, pipeline) + + async def _run_processing_parallel(self, processing: dict, ctx: Context, pipeline: Pipeline) -> None: + """ + Execute the processing functions for a particular node simultaneously, + independent of the order. + + Picked depending on the value of the :py:class:`.Pipeline`'s `parallelize_processing` flag. 
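Every user callback above is executed through `wrap_sync_function_in_async` from `dff.pipeline.service.utils`. Its body is not part of this diff; a helper with this contract plausibly looks like the sketch below (an assumption, not the actual source):

```python
import asyncio


async def wrap_sync_function_in_async(func, *args, **kwargs):
    # Coroutine functions are awaited; plain callables are invoked directly.
    if asyncio.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    return func(*args, **kwargs)
```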
+ """ + results = await asyncio.gather( + *[wrap_sync_function_in_async(func, ctx, pipeline) for func in processing.values()], + return_exceptions=True, ) + for exc, (processing_name, processing_func) in zip(results, processing.items()): + if isinstance(exc, Exception): + logger.error( + f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}", + exc_info=exc, + ) + + async def _run_processing_sequential(self, processing: dict, ctx: Context, pipeline: Pipeline) -> None: + """ + Execute the processing functions for a particular node in-order. + + Picked depending on the value of the :py:class:`.Pipeline`'s `parallelize_processing` flag. + """ + for processing_name, processing_func in processing.items(): + try: + await wrap_sync_function_in_async(processing_func, ctx, pipeline) + except Exception as exc: + logger.error( + f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}", + exc_info=exc, + ) + + async def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline) -> None: + """ + Run `PRE_TRANSITIONS_PROCESSING` functions for a particular node. + Pre-transition processing functions can modify the context state + before the direction of the next transition is determined depending on that state. + + The execution order depends on the value of the :py:class:`.Pipeline`'s + `parallelize_processing` flag. + """ + ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"]) + pre_transitions_processing = ctx.framework_states["actor"]["previous_node"].pre_transitions_processing + + if pipeline.parallelize_processing: + await self._run_processing_parallel(pre_transitions_processing, ctx, pipeline) + else: + await self._run_processing_sequential(pre_transitions_processing, ctx, pipeline) + ctx.framework_states["actor"]["pre_transitions_processed_node"] = ctx.framework_states["actor"][ "processed_node" ] del ctx.framework_states["actor"]["processed_node"] - return ctx - def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: + async def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline) -> None: + """ + Run `PRE_RESPONSE_PROCESSING` functions for a particular node. + Pre-response processing functions can modify the response before it is + returned to the user. + + The execution order depends on the value of the :py:class:`.Pipeline`'s + `parallelize_processing` flag. 
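Both strategies consume the same processing sections of the script. A sketch of a fragment they would execute (the callback names are illustrative; `PRE_TRANSITIONS_PROCESSING` is assumed to be re-exported by `dff.script` like the other keywords):

```python
from dff.pipeline import Pipeline
from dff.script import GLOBAL, PRE_TRANSITIONS_PROCESSING, Context


def count_turns(ctx: Context, _: Pipeline):
    ctx.misc["turns"] = ctx.misc.get("turns", 0) + 1


async def extract_slots(ctx: Context, _: Pipeline):
    ...


script_fragment = {
    GLOBAL: {
        PRE_TRANSITIONS_PROCESSING: {
            "extract_slots": extract_slots,
            "count_turns": count_turns,  # in-order unless parallelize_processing=True
        }
    }
}
```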
+ """ ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["next_node"]) - ctx = ctx.framework_states["actor"]["next_node"].run_pre_response_processing(ctx, pipeline, *args, **kwargs) + pre_response_processing = ctx.framework_states["actor"]["next_node"].pre_response_processing + + if pipeline.parallelize_processing: + await self._run_processing_parallel(pre_response_processing, ctx, pipeline) + else: + await self._run_processing_sequential(pre_response_processing, ctx, pipeline) + ctx.framework_states["actor"]["pre_response_processed_node"] = ctx.framework_states["actor"]["processed_node"] del ctx.framework_states["actor"]["processed_node"] - return ctx - def _get_true_label( + async def _get_true_label( self, transitions: dict, ctx: Context, pipeline: Pipeline, flow_label: LabelType, transition_info: str = "", - *args, - **kwargs, ) -> Optional[NodeLabel3Type]: true_labels = [] - for label, condition in transitions.items(): - if self.condition_handler(condition, ctx, pipeline, *args, **kwargs): + + cond_booleans = await asyncio.gather( + *(self.condition_handler(condition, ctx, pipeline) for condition in transitions.values()) + ) + for label, cond_is_true in zip(transitions.keys(), cond_booleans): + if cond_is_true: if callable(label): - label = label(ctx, pipeline, *args, **kwargs) + label = await wrap_sync_function_in_async(label, ctx, pipeline) # TODO: explicit handling of errors if label is None: continue - label = normalize_label(label, flow_label) true_labels += [label] true_labels = [ ((label[0] if label[0] else flow_label),) @@ -311,8 +370,10 @@ def _get_true_label( logger.debug(f"{transition_info} transitions sorted by priority = {true_labels}") return true_label - def _run_handlers(self, ctx, pipeline: Pipeline, actor_stage: ActorStage, *args, **kwargs): - [handler(ctx, pipeline, *args, **kwargs) for handler in self.handlers.get(actor_stage, [])] + async def _run_handlers(self, ctx, pipeline: Pipeline, actor_stage: ActorStage): + stage_handlers = self.handlers.get(actor_stage, []) + async_handlers = [wrap_sync_function_in_async(handler, ctx, pipeline) for handler in stage_handlers] + await asyncio.gather(*async_handlers) def _choose_label( self, specific_label: Optional[NodeLabel3Type], general_label: Optional[NodeLabel3Type] @@ -360,7 +421,7 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True): # validate responsing response_func = normalize_response(node.response) try: - response_result = response_func(ctx, pipeline) + response_result = asyncio.run(wrap_sync_function_in_async(response_func, ctx, pipeline)) if not isinstance(response_result, Message): msg = ( "Expected type of response_result is `Message`.\n" @@ -390,9 +451,9 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True): return error_msgs -def default_condition_handler( - condition: Callable, ctx: Context, pipeline: Pipeline, *args, **kwargs -) -> Callable[[Context, Pipeline, Any, Any], bool]: +async def default_condition_handler( + condition: Callable, ctx: Context, pipeline: Pipeline +) -> Callable[[Context, Pipeline], bool]: """ The simplest and quickest condition handler for trivial condition handling returns the callable condition: @@ -400,4 +461,4 @@ def default_condition_handler( :param ctx: Context of current condition. :param pipeline: Pipeline we use in this condition. 
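Because conditions are now gathered concurrently and `default_condition_handler` awaits `wrap_sync_function_in_async(condition, ctx, pipeline)`, user-defined conditions may themselves be coroutines. A hedged sketch (the condition itself is made up):

```python
from dff.pipeline import Pipeline
from dff.script import Context


async def seen_before(ctx: Context, _: Pipeline) -> bool:
    # an awaitable lookup (e.g. a database query) could go here
    return len(ctx.requests) > 1
```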
""" - return condition(ctx, pipeline, *args, **kwargs) + return await wrap_sync_function_in_async(condition, ctx, pipeline) diff --git a/dff/pipeline/pipeline/component.py b/dff/pipeline/pipeline/component.py index 2291a4d55..d58ddb39d 100644 --- a/dff/pipeline/pipeline/component.py +++ b/dff/pipeline/pipeline/component.py @@ -12,7 +12,7 @@ import abc import asyncio import copy -from typing import Optional, Union, Awaitable, ForwardRef +from typing import Optional, Awaitable, ForwardRef from dff.script import Context @@ -163,27 +163,24 @@ async def run_extra_handler(self, stage: ExtraHandlerType, ctx: Context, pipelin logger.warning(f"{type(self).__name__} '{self.name}' {extra_handler.stage} extra handler timed out!") @abc.abstractmethod - async def _run(self, ctx: Context, pipeline: Optional[Pipeline] = None) -> Optional[Context]: + async def _run(self, ctx: Context, pipeline: Pipeline) -> None: """ A method for running pipeline component, it is overridden in all its children. This method is run after the component's timeout is set (if needed). :param ctx: Current dialog :py:class:`~.Context`. :param pipeline: This :py:class:`~.Pipeline`. - :return: :py:class:`~.Context` if this is a synchronous service or `None`, - asynchronous services shouldn't modify :py:class:`~.Context`. """ raise NotImplementedError - async def __call__(self, ctx: Context, pipeline: Optional[Pipeline] = None) -> Optional[Union[Context, Awaitable]]: + async def __call__(self, ctx: Context, pipeline: Pipeline) -> Optional[Awaitable]: """ A method for calling pipeline components. It sets up timeout if this component is asynchronous and executes it using :py:meth:`~._run` method. :param ctx: Current dialog :py:class:`~.Context`. :param pipeline: This :py:class:`~.Pipeline`. - :return: :py:class:`~.Context` if this is a synchronous service or :py:class:`~.typing.const.Awaitable`, - asynchronous services shouldn't modify :py:class:`~.Context`. + :return: `None` if the service is synchronous; an `Awaitable` otherwise. """ if self.asynchronous: task = asyncio.create_task(self._run(ctx, pipeline)) diff --git a/dff/pipeline/pipeline/pipeline.py b/dff/pipeline/pipeline/pipeline.py index 49864f7d1..fb548e9f5 100644 --- a/dff/pipeline/pipeline/pipeline.py +++ b/dff/pipeline/pipeline/pipeline.py @@ -74,6 +74,9 @@ class Pipeline: - `_services_pipeline` is a pipeline root :py:class:`~.ServiceGroup` object, - `actor` is a pipeline actor, found among services. + :param parallelize_processing: This flag determines whether or not the functions + defined in the ``PRE_RESPONSE_PROCESSING`` and ``PRE_TRANSITIONS_PROCESSING`` sections + of the script should be parallelized over respective groups. """ @@ -94,6 +97,7 @@ def __init__( after_handler: Optional[ExtraHandlerBuilder] = None, timeout: Optional[float] = None, optimization_warnings: bool = False, + parallelize_processing: bool = False, ): self.actor: Actor = None self.messenger_interface = CLIMessengerInterface() if messenger_interface is None else messenger_interface @@ -109,7 +113,7 @@ def __init__( self._services_pipeline.path = ".pipeline" actor_exists = finalize_service_group(self._services_pipeline, path=self._services_pipeline.path) if not actor_exists: - raise Exception("Actor not found in pipeline!") + raise Exception("Actor not found in the pipeline!") else: self.set_actor( script, @@ -127,6 +131,8 @@ def __init__( if optimization_warnings: self._services_pipeline.log_optimization_warnings() + self.parallelize_processing = parallelize_processing + # NB! 
The following API is highly experimental and may be removed at ANY time WITHOUT FURTHER NOTICE!! self._clean_turn_cache = True if self._clean_turn_cache: @@ -207,6 +213,7 @@ def from_script( validation_stage: Optional[bool] = None, condition_handler: Optional[Callable] = None, verbose: bool = True, + parallelize_processing: bool = False, handlers: Optional[Dict[ActorStage, List[Callable]]] = None, context_storage: Optional[Union[DBContextStorage, Dict]] = None, messenger_interface: Optional[MessengerInterface] = None, @@ -230,6 +237,9 @@ def from_script( It is executed by default. Defaults to `None`. :param condition_handler: Handler that processes a call of actor condition functions. Defaults to `None`. :param verbose: If it is `True`, logging is used in actor. Defaults to `True`. + :param parallelize_processing: This flag determines whether or not the functions + defined in the ``PRE_RESPONSE_PROCESSING`` and ``PRE_TRANSITIONS_PROCESSING`` sections + of the script should be parallelized over respective groups. :param handlers: This variable is responsible for the usage of external handlers on the certain stages of work of :py:class:`~dff.script.Actor`. @@ -257,6 +267,7 @@ def from_script( validation_stage=validation_stage, condition_handler=condition_handler, verbose=verbose, + parallelize_processing=parallelize_processing, handlers=handlers, messenger_interface=messenger_interface, context_storage=context_storage, @@ -314,13 +325,12 @@ def from_dict(cls, dictionary: PipelineBuilder) -> "Pipeline": """ return cls(**dictionary) - async def _run_pipeline(self, request: Message, ctx_id: Optional[Hashable] = None) -> Context: + async def _run_pipeline( + self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None + ) -> Context: """ - Method that runs pipeline once for user request. - - :param request: (required) Any user request. - :param ctx_id: Current dialog id; if `None`, new dialog will be created. - :return: Dialog `Context`. + Method that should be invoked on user input. + This method has the same signature as :py:class:`~dff.pipeline.types.PipelineRunnerFunction`. """ if ctx_id is None: ctx = Context() @@ -329,9 +339,16 @@ async def _run_pipeline(self, request: Message, ctx_id: Optional[Hashable] = Non else: ctx = self.context_storage.get(ctx_id, Context(id=ctx_id)) + if update_ctx_misc is not None: + ctx.misc.update(update_ctx_misc) + ctx.framework_states[PIPELINE_STATE_KEY] = {} ctx.add_request(request) - ctx = await self._services_pipeline(ctx, self) + result = await self._services_pipeline(ctx, self) + + if asyncio.iscoroutine(result): + await result + del ctx.framework_states[PIPELINE_STATE_KEY] if isinstance(self.context_storage, DBContextStorage): @@ -353,17 +370,17 @@ def run(self): """ asyncio.run(self.messenger_interface.connect(self._run_pipeline)) - def __call__(self, request: Message, ctx_id: Hashable) -> Context: + def __call__( + self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None + ) -> Context: """ Method that executes pipeline once. Basically, it is a shortcut for `_run_pipeline`. NB! When pipeline is executed this way, `messenger_interface` won't be initiated nor connected. - :param request: Any user request. - :param ctx_id: Current dialog id. - :return: Dialog `Context`. + This method has the same signature as :py:class:`~dff.pipeline.types.PipelineRunnerFunction`. 
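A minimal usage sketch of the extended runner signature (assuming `pipeline` is a `Pipeline` constructed elsewhere; the id and misc values are arbitrary):

```python
from dff.script import Message

ctx = pipeline(
    Message(text="hello"),
    ctx_id="user-1",                   # None would start a new dialog
    update_ctx_misc={"locale": "en"},  # merged into ctx.misc before the turn runs
)
print(ctx.last_response)
```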
""" - return asyncio.run(self._run_pipeline(request, ctx_id)) + return asyncio.run(self._run_pipeline(request, ctx_id, update_ctx_misc)) @property def script(self) -> Script: diff --git a/dff/pipeline/service/group.py b/dff/pipeline/service/group.py index 1b1591b0f..1f845cde1 100644 --- a/dff/pipeline/service/group.py +++ b/dff/pipeline/service/group.py @@ -40,7 +40,6 @@ class ServiceGroup(PipelineComponent): Components in synchronous groups are executed consequently (no matter is they are synchronous or asynchronous). Components in asynchronous groups are executed simultaneously. Group can be asynchronous only if all components in it are asynchronous. - Group containing actor can be synchronous only. :param components: A `ServiceGroupBuilder` object, that will be added to the group. :type components: :py:data:`~.ServiceGroupBuilder` @@ -97,7 +96,7 @@ def __init__( else: raise Exception(f"Unknown type for ServiceGroup {components}") - async def _run_services_group(self, ctx: Context, pipeline: Pipeline) -> Context: + async def _run_services_group(self, ctx: Context, pipeline: Pipeline) -> None: """ Method for running this service group. It doesn't include wrappers execution, start condition checking or error handling - pure execution only. @@ -107,7 +106,6 @@ async def _run_services_group(self, ctx: Context, pipeline: Pipeline) -> Context :param ctx: Current dialog context. :param pipeline: The current pipeline. - :return: Current dialog context. """ self._set_state(ctx, ComponentExecutionState.RUNNING) @@ -123,33 +121,29 @@ async def _run_services_group(self, ctx: Context, pipeline: Pipeline) -> Context else: for service in self.components: service_result = await service(ctx, pipeline) - if not service.asynchronous and isinstance(service_result, Context): - ctx = service_result - elif service.asynchronous and isinstance(service_result, Awaitable): + if service.asynchronous and isinstance(service_result, Awaitable): await service_result failed = any([service.get_state(ctx) == ComponentExecutionState.FAILED for service in self.components]) self._set_state(ctx, ComponentExecutionState.FAILED if failed else ComponentExecutionState.FINISHED) - return ctx async def _run( self, ctx: Context, - pipeline: Pipeline = None, - ) -> Optional[Context]: + pipeline: Pipeline, + ) -> None: """ Method for handling this group execution. Executes before and after execution wrappers, checks start condition and catches runtime exceptions. :param ctx: Current dialog context. :param pipeline: The current pipeline. - :return: Current dialog context if synchronous, else `None`. """ await self.run_extra_handler(ExtraHandlerType.BEFORE, ctx, pipeline) try: if self.start_condition(ctx, pipeline): - ctx = await self._run_services_group(ctx, pipeline) + await self._run_services_group(ctx, pipeline) else: self._set_state(ctx, ComponentExecutionState.NOT_RUN) @@ -158,7 +152,6 @@ async def _run( logger.error(f"ServiceGroup '{self.name}' execution failed!\n{e}") await self.run_extra_handler(ExtraHandlerType.AFTER, ctx, pipeline) - return ctx if not self.asynchronous else None def log_optimization_warnings(self): """ diff --git a/dff/pipeline/service/service.py b/dff/pipeline/service/service.py index 406ea3a81..76787143e 100644 --- a/dff/pipeline/service/service.py +++ b/dff/pipeline/service/service.py @@ -7,12 +7,9 @@ Service group can be synchronous or asynchronous. Service is an atomic part of a pipeline. Service can be asynchronous only if its handler is a coroutine. -Actor wrapping service can be synchronous only. 
+Actor wrapping service is asynchronous. """ -# TODO: change last sentence, when actor will be asynchronous - import logging -import asyncio import inspect from typing import Optional, ForwardRef @@ -39,7 +36,6 @@ class Service(PipelineComponent): Service can be included into pipeline as object or a dictionary. Service group can be synchronous or asynchronous. Service can be asynchronous only if its handler is a coroutine. - Actor wrapping service can be synchronous only. :param handler: A service function or an actor. :type handler: :py:data:`~.ServiceBuilder` @@ -91,15 +87,15 @@ def __init__( before_handler, after_handler, timeout, - asynchronous, - asyncio.iscoroutinefunction(handler), + True, + True, start_condition, name, ) else: raise Exception(f"Unknown type of service handler: {handler}") - async def _run_handler(self, ctx: Context, pipeline: Pipeline): + async def _run_handler(self, ctx: Context, pipeline: Pipeline) -> None: """ Method for service `handler` execution. Handler has three possible signatures, so this method picks the right one to invoke. @@ -124,30 +120,27 @@ async def _run_handler(self, ctx: Context, pipeline: Pipeline): else: raise Exception(f"Too many parameters required for service '{self.name}' handler: {handler_params}!") - def _run_as_actor(self, ctx: Context, pipeline: Pipeline): + async def _run_as_actor(self, ctx: Context, pipeline: Pipeline) -> None: """ Method for running this service if its handler is an `Actor`. Catches runtime exceptions. :param ctx: Current dialog context. - :return: Context, mutated by actor. """ try: - ctx = pipeline.actor(pipeline, ctx) + await pipeline.actor(pipeline, ctx) self._set_state(ctx, ComponentExecutionState.FINISHED) except Exception as exc: self._set_state(ctx, ComponentExecutionState.FAILED) logger.error(f"Actor '{self.name}' execution failed!\n{exc}") - return ctx - async def _run_as_service(self, ctx: Context, pipeline: Pipeline): + async def _run_as_service(self, ctx: Context, pipeline: Pipeline) -> None: """ Method for running this service if its handler is not an Actor. Checks start condition and catches runtime exceptions. :param ctx: Current dialog context. :param pipeline: Current pipeline. - :return: `None` """ try: if self.start_condition(ctx, pipeline): @@ -160,28 +153,23 @@ async def _run_as_service(self, ctx: Context, pipeline: Pipeline): self._set_state(ctx, ComponentExecutionState.FAILED) logger.error(f"Service '{self.name}' execution failed!\n{e}") - async def _run(self, ctx: Context, pipeline: Optional[Pipeline] = None) -> Optional[Context]: + async def _run(self, ctx: Context, pipeline: Pipeline) -> None: """ Method for handling this service execution. Executes before and after execution wrappers, launches `_run_as_actor` or `_run_as_service` method. :param ctx: (required) Current dialog context. :param pipeline: the current pipeline. - :return: `Context` if this service's handler is an `Actor` else `None`. 
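With `Context` no longer passed back through return values, services communicate solely by mutating the shared context in place. A sketch of a conforming service (the handler is illustrative):

```python
from dff.pipeline import Pipeline
from dff.script import Context


async def annotate_request(ctx: Context, _: Pipeline) -> None:
    request = ctx.last_request
    # store results on the context instead of returning a new one
    ctx.misc["request_length"] = len(request.text) if request and request.text else 0
```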
""" await self.run_extra_handler(ExtraHandlerType.BEFORE, ctx, pipeline) if isinstance(self.handler, str) and self.handler == "ACTOR": - ctx = self._run_as_actor(ctx, pipeline) + await self._run_as_actor(ctx, pipeline) else: await self._run_as_service(ctx, pipeline) await self.run_extra_handler(ExtraHandlerType.AFTER, ctx, pipeline) - if isinstance(self.handler, str) and self.handler == "ACTOR": - return ctx - return None - @property def info_dict(self) -> dict: """ diff --git a/dff/pipeline/types.py b/dff/pipeline/types.py index 39584a303..ef7bef9ae 100644 --- a/dff/pipeline/types.py +++ b/dff/pipeline/types.py @@ -7,10 +7,10 @@ """ from abc import ABC from enum import unique, Enum -from typing import Callable, Union, Awaitable, Dict, List, Optional, NewType, Iterable, Any +from typing import Callable, Union, Awaitable, Dict, List, Optional, NewType, Iterable, Any, Protocol, Hashable from dff.context_storages import DBContextStorage -from dff.script import Context, ActorStage, NodeLabel2Type, Script +from dff.script import Context, ActorStage, NodeLabel2Type, Script, Message from typing_extensions import NotRequired, TypedDict, TypeAlias from pydantic import BaseModel @@ -25,6 +25,32 @@ _ForwardExtraHandlerRuntimeInfo = NewType("ExtraHandlerRuntimeInfo", Any) +class PipelineRunnerFunction(Protocol): + """ + Protocol for pipeline running. + """ + + def __call__( + self, message: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None + ) -> Context: + """ + :param message: User request for pipeline to process. + :param ctx_id: + ID of the context that the new request belongs to. + Optional, None by default. + If set to `None`, a new context will be created with `message` being the first request. + :param update_ctx_misc: + Dictionary to be passed as an argument to `ctx.misc.update`. + This argument can be used to store values in the `misc` dictionary before anything else runs. + Optional; None by default. + If set to `None`, `ctx.misc.update` will not be called. + :return: + Context instance that pipeline processed. + The context instance has the id of `ctx_id`. + If `ctx_id` is `None`, context instance has an id generated with `uuid.uuid4`. + """ + + @unique class ComponentExecutionState(str, Enum): """ @@ -234,6 +260,7 @@ class ExtraHandlerRuntimeInfo(BaseModel): "before_handler": NotRequired[Optional[ExtraHandlerBuilder]], "after_handler": NotRequired[Optional[ExtraHandlerBuilder]], "optimization_warnings": NotRequired[bool], + "parallelize_processing": NotRequired[bool], "script": Union[Script, Dict], "start_label": NodeLabel2Type, "fallback_label": NotRequired[Optional[NodeLabel2Type]], diff --git a/dff/script/conditions/std_conditions.py b/dff/script/conditions/std_conditions.py index 4dcce993c..bd4186aa9 100644 --- a/dff/script/conditions/std_conditions.py +++ b/dff/script/conditions/std_conditions.py @@ -8,7 +8,7 @@ These conditions can be used to check the current context, the user's input, or other factors that may affect the conversation flow. """ -from typing import Callable, Pattern, Union, Any, List, Optional +from typing import Callable, Pattern, Union, List, Optional import logging import re @@ -21,7 +21,7 @@ @validate_call -def exact_match(match: Message, skip_none: bool = True) -> Callable[..., bool]: +def exact_match(match: Message, skip_none: bool = True) -> Callable[[Context, Pipeline], bool]: """ Return function handler. This handler returns `True` only if the last user phrase is the same Message as the :py:const:`match`. 
@@ -31,7 +31,7 @@ def exact_match(match: Message, skip_none: bool = True) -> Callable[..., bool]: :param skip_none: Whether fields should be compared if they are `None` in :py:const:`match`. """ - def exact_match_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def exact_match_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: request = ctx.last_request if request is None: return False @@ -50,9 +50,7 @@ def exact_match_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwa @validate_call -def regexp( - pattern: Union[str, Pattern], flags: Union[int, re.RegexFlag] = 0 -) -> Callable[[Context, Pipeline, Any, Any], bool]: +def regexp(pattern: Union[str, Pattern], flags: Union[int, re.RegexFlag] = 0) -> Callable[[Context, Pipeline], bool]: """ Return function handler. This handler returns `True` only if the last user phrase contains :py:const:`pattern ` with :py:const:`flags `. @@ -62,7 +60,7 @@ def regexp( """ pattern = re.compile(pattern, flags) - def regexp_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def regexp_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: request = ctx.last_request if isinstance(request, Message): if request.text is None: @@ -98,7 +96,7 @@ def check_cond_seq(cond_seq: list): @validate_call -def aggregate(cond_seq: list, aggregate_func: Callable = _any) -> Callable[[Context, Pipeline, Any, Any], bool]: +def aggregate(cond_seq: list, aggregate_func: Callable = _any) -> Callable[[Context, Pipeline], bool]: """ Aggregate multiple functions into one by using aggregating function. @@ -107,9 +105,9 @@ def aggregate(cond_seq: list, aggregate_func: Callable = _any) -> Callable[[Cont """ check_cond_seq(cond_seq) - def aggregate_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def aggregate_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: try: - return bool(aggregate_func([cond(ctx, pipeline, *args, **kwargs) for cond in cond_seq])) + return bool(aggregate_func([cond(ctx, pipeline) for cond in cond_seq])) except Exception as exc: logger.error(f"Exception {exc} for {cond_seq}, {aggregate_func} and {ctx.last_request}", exc_info=exc) return False @@ -118,7 +116,7 @@ def aggregate_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwarg @validate_call -def any(cond_seq: list) -> Callable[[Context, Pipeline, Any, Any], bool]: +def any(cond_seq: list) -> Callable[[Context, Pipeline], bool]: """ Return function handler. This handler returns `True` if any function from the list is `True`. @@ -127,14 +125,14 @@ def any(cond_seq: list) -> Callable[[Context, Pipeline, Any, Any], bool]: """ _agg = aggregate(cond_seq, _any) - def any_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: - return _agg(ctx, pipeline, *args, **kwargs) + def any_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: + return _agg(ctx, pipeline) return any_condition_handler @validate_call -def all(cond_seq: list) -> Callable[[Context, Pipeline, Any, Any], bool]: +def all(cond_seq: list) -> Callable[[Context, Pipeline], bool]: """ Return function handler. This handler returns `True` only if all functions from the list are `True`. 
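For user code, the practical effect of the narrowed callables is that the trailing varargs disappear from every condition. A before/after sketch:

```python
from dff.pipeline import Pipeline
from dff.script import Context


# before this change
def has_request_old(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool:
    return ctx.last_request is not None


# after this change
def has_request(ctx: Context, pipeline: Pipeline) -> bool:
    return ctx.last_request is not None
```

Composition of the built-in conditions is otherwise unchanged, e.g.:

```python
from dff.script import Message, conditions as cnd

greeting = cnd.any([cnd.exact_match(Message(text="hi")), cnd.regexp(r"hello")])
not_greeting = cnd.negation(greeting)
```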
@@ -143,14 +141,14 @@ def all(cond_seq: list) -> Callable[[Context, Pipeline, Any, Any], bool]: """ _agg = aggregate(cond_seq, _all) - def all_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: - return _agg(ctx, pipeline, *args, **kwargs) + def all_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: + return _agg(ctx, pipeline) return all_condition_handler @validate_call -def negation(condition: Callable) -> Callable[[Context, Pipeline, Any, Any], bool]: +def negation(condition: Callable) -> Callable[[Context, Pipeline], bool]: """ Return function handler. This handler returns negation of the :py:func:`~condition`: `False` if :py:func:`~condition` holds `True` and returns `True` otherwise. @@ -158,8 +156,8 @@ def negation(condition: Callable) -> Callable[[Context, Pipeline, Any, Any], boo :param condition: Any :py:func:`~condition`. """ - def negation_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: - return not condition(ctx, pipeline, *args, **kwargs) + def negation_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: + return not condition(ctx, pipeline) return negation_condition_handler @@ -169,7 +167,7 @@ def has_last_labels( flow_labels: Optional[List[str]] = None, labels: Optional[List[NodeLabel2Type]] = None, last_n_indices: int = 1, -) -> Callable[[Context, Pipeline, Any, Any], bool]: +) -> Callable[[Context, Pipeline], bool]: """ Return condition handler. This handler returns `True` if any label from last :py:const:`last_n_indices` context labels is in @@ -183,7 +181,7 @@ def has_last_labels( flow_labels = [] if flow_labels is None else flow_labels labels = [] if labels is None else labels - def has_last_labels_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def has_last_labels_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: label = list(ctx.labels.values())[-last_n_indices:] for label in list(ctx.labels.values())[-last_n_indices:]: label = label if label else (None, None) @@ -195,24 +193,24 @@ def has_last_labels_condition_handler(ctx: Context, pipeline: Pipeline, *args, * @validate_call -def true() -> Callable[[Context, Pipeline, Any, Any], bool]: +def true() -> Callable[[Context, Pipeline], bool]: """ Return function handler. This handler always returns `True`. """ - def true_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def true_handler(ctx: Context, pipeline: Pipeline) -> bool: return True return true_handler @validate_call -def false() -> Callable[[Context, Pipeline, Any, Any], bool]: +def false() -> Callable[[Context, Pipeline], bool]: """ Return function handler. This handler always returns `False`. """ - def false_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def false_handler(ctx: Context, pipeline: Pipeline) -> bool: return False return false_handler diff --git a/dff/script/core/context.py b/dff/script/core/context.py index 78ee18072..730beec2f 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -18,10 +18,10 @@ """ import logging from uuid import UUID, uuid4 - from typing import Any, Optional, Union, Dict, List, Set from pydantic import BaseModel, Field, field_validator + from .types import NodeLabel2Type, ModuleName from .message import Message @@ -278,23 +278,5 @@ def current_node(self) -> Optional[Node]: return node - def overwrite_current_node_in_processing(self, processed_node: Node): - """ - Set the current node to be `processed_node`. 
- This method only works in processing functions (pre-response and pre-transition). - - The actual current node is not changed. - - :param processed_node: `node` to set as the current node. - """ - is_processing = self.framework_states.get("actor", {}).get("processed_node") - if is_processing: - self.framework_states["actor"]["processed_node"] = Node.model_validate(processed_node) - else: - logger.warning( - f"The `{self.overwrite_current_node_in_processing.__name__}` " - "method can only be called from processing functions (either pre-response or pre-transition)." - ) - Context.model_rebuild() diff --git a/dff/script/core/keywords.py b/dff/script/core/keywords.py index 076577ae4..805c10db5 100644 --- a/dff/script/core/keywords.py +++ b/dff/script/core/keywords.py @@ -64,8 +64,8 @@ class Keywords(str, Enum): `{"PRE_RESPONSE_PROC_0": pre_response_proc_func_0, ..., "PRE_RESPONSE_PROC_N": pre_response_proc__func_N}`, where `"PRE_RESPONSE_PROC_i"` is an arbitrary name of the preprocessing stage in the pipeline. - The order of `pre_response_proc__func_i` calls is defined by the order - in which the preprocessing `dict` is defined. + Unless the :py:class:`~dff.pipeline.pipeline.Pipeline`'s `parallelize_processing` flag + is set to `True`, calls to `pre_response_proc__func_i` are made in-order. PRE_TRANSITIONS_PROCESSING: Enum(auto) The keyword specifying the preprocessing that is called before the transition. @@ -75,8 +75,8 @@ class Keywords(str, Enum): "PRE_TRANSITIONS_PROC_N": pre_transitions_proc_func_N}`, where `"PRE_TRANSITIONS_PROC_i"` is an arbitrary name of the preprocessing stage in the pipeline. - The order of `pre_transitions_proc_func_i` calls is defined by the order - in which the preprocessing `dict` is defined. + Unless the :py:class:`~dff.pipeline.pipeline.Pipeline`'s `parallelize_processing` flag + is set to `True`, calls to `pre_transitions_proc_func_i` are made in-order. """ diff --git a/dff/script/core/normalization.py b/dff/script/core/normalization.py index a0f91407f..ef9f75419 100644 --- a/dff/script/core/normalization.py +++ b/dff/script/core/normalization.py @@ -7,7 +7,7 @@ """ import logging -from typing import Union, Callable, Any, Dict, Optional, ForwardRef +from typing import Union, Callable, Optional, ForwardRef from .keywords import Keywords from .context import Context @@ -21,7 +21,9 @@ Pipeline = ForwardRef("Pipeline") -def normalize_label(label: NodeLabelType, default_flow_label: LabelType = "") -> Union[Callable, NodeLabel3Type]: +def normalize_label( + label: NodeLabelType, default_flow_label: LabelType = "" +) -> Union[Callable[[Context, Pipeline], NodeLabel3Type], NodeLabel3Type]: """ The function that is used for normalization of :py:const:`default_flow_label `. 
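With `overwrite_current_node_in_processing` removed above, processing callbacks mutate the node returned by `ctx.current_node` directly, exactly as the new parallel-processing test at the end of this patch does. A sketch (it assumes the node's response is a `Message` rather than a callable); note that `Node` is now configured with `validate_assignment=True`, so the assignment below is re-validated by pydantic:

```python
from dff.pipeline import Pipeline
from dff.script import Context, Message


def add_prefix(ctx: Context, _: Pipeline):
    node = ctx.current_node  # the node copy being processed this turn
    node.response = Message(text=f"prefix: {node.response.text}")
```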
@@ -34,9 +36,9 @@ def normalize_label(label: NodeLabelType, default_flow_label: LabelType = "") -> """ if callable(label): - def get_label_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def get_label_handler(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: try: - new_label = label(ctx, pipeline, *args, **kwargs) + new_label = label(ctx, pipeline) new_label = normalize_label(new_label, default_flow_label) flow_label, node_label, _ = new_label node = pipeline.script.get(flow_label, {}).get(node_label) @@ -62,7 +64,7 @@ def get_label_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Node return (flow_label, label[1], label[2]) -def normalize_condition(condition: ConditionType) -> Callable: +def normalize_condition(condition: ConditionType) -> Callable[[Context, Pipeline], bool]: """ The function that is used to normalize `condition` @@ -71,9 +73,9 @@ def normalize_condition(condition: ConditionType) -> Callable: """ if callable(condition): - def callable_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def callable_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: try: - return condition(ctx, pipeline, *args, **kwargs) + return condition(ctx, pipeline) except Exception as exc: logger.error(f"Exception {exc} of function {condition}", exc_info=exc) return False @@ -82,10 +84,12 @@ def callable_condition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs @validate_call -def normalize_response(response: Optional[Union[Message, Callable[..., Message]]]) -> Callable[..., Message]: +def normalize_response( + response: Optional[Union[Message, Callable[[Context, Pipeline], Message]]] +) -> Callable[[Context, Pipeline], Message]: """ - This function is used to normalize response, if response Callable, it is returned, otherwise - response is wrapped to the function and this function is returned. + This function is used to normalize response. If the response is a Callable, it is returned, otherwise + the response is wrapped in an asynchronous function and this function is returned. :param response: Response to normalize. :return: Function that returns callable response. @@ -100,33 +104,7 @@ def normalize_response(response: Optional[Union[Message, Callable[..., Message]] else: raise TypeError(type(response)) - def response_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs): + async def response_handler(ctx: Context, pipeline: Pipeline): return result return response_handler - - -@validate_call -def normalize_processing(processing: Dict[Any, Callable]) -> Callable: - """ - This function is used to normalize processing. - It returns function that consecutively applies all preprocessing stages from dict. - - :param processing: Processing which contains all preprocessing stages in a format "PROC_i" -> proc_func_i. - :return: Function that consequentially applies all preprocessing stages from dict. 
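A constant response is now wrapped in an asynchronous handler, so the actor awaits every response uniformly. A minimal demonstration (the `None` arguments stand in for the unused context and pipeline):

```python
import asyncio

from dff.script import Message
from dff.script.core.normalization import normalize_response

handler = normalize_response(Message(text="hi"))
print(asyncio.run(handler(None, None)))  # Message(text='hi')
```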
- """ - if isinstance(processing, dict): - - def processing_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - for processing_name, processing_func in processing.items(): - try: - if processing_func is not None: - ctx = processing_func(ctx, pipeline, *args, **kwargs) - except Exception as exc: - logger.error( - f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}", - exc_info=exc, - ) - return ctx - - return processing_handler diff --git a/dff/script/core/script.py b/dff/script/core/script.py index 31fc7c9a5..7896c415e 100644 --- a/dff/script/core/script.py +++ b/dff/script/core/script.py @@ -15,7 +15,7 @@ from .types import LabelType, NodeLabelType, ConditionType, NodeLabel3Type from .message import Message from .keywords import Keywords -from .normalization import normalize_response, normalize_processing, normalize_condition, normalize_label, validate_call +from .normalization import normalize_condition, normalize_label, validate_call from typing import ForwardRef logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ Context = ForwardRef("Context") -class Node(BaseModel, extra="forbid"): +class Node(BaseModel, extra="forbid", validate_assignment=True): """ The class for the `Node` object. """ @@ -53,36 +53,6 @@ def normalize_transitions( } return transitions - def run_response(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - """ - Executes the normalized response. - See details in the :py:func:`~normalize_response` function of `normalization.py`. - """ - response = normalize_response(self.response) - return response(ctx, pipeline, *args, **kwargs) - - def run_pre_response_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - """ - Executes pre-processing of responses. - """ - return self.run_processing(self.pre_response_processing, ctx, pipeline, *args, **kwargs) - - def run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - """ - Executes pre-processing of transitions. - """ - return self.run_processing(self.pre_transitions_processing, ctx, pipeline, *args, **kwargs) - - def run_processing( - self, processing: Dict[Any, Callable], ctx: Context, pipeline: Pipeline, *args, **kwargs - ) -> Context: - """ - Executes the normalized processing. - See details in the :py:func:`~normalize_processing` function of `normalization.py`. - """ - processing = normalize_processing(processing) - return processing(ctx, pipeline, *args, **kwargs) - class Script(BaseModel, extra="forbid"): """ diff --git a/dff/script/labels/std_labels.py b/dff/script/labels/std_labels.py index 0e250c245..1c2c03322 100644 --- a/dff/script/labels/std_labels.py +++ b/dff/script/labels/std_labels.py @@ -27,7 +27,7 @@ def repeat(priority: Optional[float] = None) -> Callable: :param priority: Priority of transition. Uses `Pipeline.actor.label_priority` if priority not defined. """ - def repeat_transition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def repeat_transition_handler(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: current_priority = pipeline.actor.label_priority if priority is None else priority if len(ctx.labels) >= 1: flow_label, label = list(ctx.labels.values())[-1] @@ -50,7 +50,7 @@ def previous(priority: Optional[float] = None) -> Callable: :param priority: Priority of transition. Uses `Pipeline.actor.label_priority` if priority not defined. 
""" - def previous_transition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def previous_transition_handler(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: current_priority = pipeline.actor.label_priority if priority is None else priority if len(ctx.labels) >= 2: flow_label, label = list(ctx.labels.values())[-2] @@ -74,7 +74,7 @@ def to_start(priority: Optional[float] = None) -> Callable: :param priority: Priority of transition. Uses `Pipeline.actor.label_priority` if priority not defined. """ - def to_start_transition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def to_start_transition_handler(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: current_priority = pipeline.actor.label_priority if priority is None else priority return (*pipeline.actor.start_label[:2], current_priority) @@ -92,7 +92,7 @@ def to_fallback(priority: Optional[float] = None) -> Callable: :param priority: Priority of transition. Uses `Pipeline.actor.label_priority` if priority not defined. """ - def to_fallback_transition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def to_fallback_transition_handler(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: current_priority = pipeline.actor.label_priority if priority is None else priority return (*pipeline.actor.fallback_label[:2], current_priority) @@ -105,8 +105,6 @@ def _get_label_by_index_shifting( priority: Optional[float] = None, increment_flag: bool = True, cyclicality_flag: bool = True, - *args, - **kwargs, ) -> NodeLabel3Type: """ Function that returns node label from the context and pipeline after shifting the index. @@ -121,7 +119,7 @@ def _get_label_by_index_shifting( :return: The tuple that consists of `(flow_label, label, priority)`. If fallback is executed `(flow_fallback_label, fallback_label, priority)` are returned. """ - flow_label, node_label, current_priority = repeat(priority)(ctx, pipeline, *args, **kwargs) + flow_label, node_label, current_priority = repeat(priority)(ctx, pipeline) labels = list(pipeline.script.get(flow_label, {})) if node_label not in labels: @@ -149,9 +147,9 @@ def forward(priority: Optional[float] = None, cyclicality_flag: bool = True) -> (e.g the element with `index = len(labels)` has `index = 0`). Defaults to `True`. """ - def forward_transition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def forward_transition_handler(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: return _get_label_by_index_shifting( - ctx, pipeline, priority, increment_flag=True, cyclicality_flag=cyclicality_flag, *args, **kwargs + ctx, pipeline, priority, increment_flag=True, cyclicality_flag=cyclicality_flag ) return forward_transition_handler @@ -170,9 +168,9 @@ def backward(priority: Optional[float] = None, cyclicality_flag: bool = True) -> (e.g the element with `index = len(labels)` has `index = 0`). Defaults to `True`. 
""" - def back_transition_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def back_transition_handler(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: return _get_label_by_index_shifting( - ctx, pipeline, priority, increment_flag=False, cyclicality_flag=cyclicality_flag, *args, **kwargs + ctx, pipeline, priority, increment_flag=False, cyclicality_flag=cyclicality_flag ) return back_transition_handler diff --git a/dff/script/responses/std_responses.py b/dff/script/responses/std_responses.py index 56cd6920e..5bd2d2576 100644 --- a/dff/script/responses/std_responses.py +++ b/dff/script/responses/std_responses.py @@ -23,7 +23,7 @@ def choice(responses: List[Message]): :param responses: A list of responses for random sampling. """ - def choice_response_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs): + def choice_response_handler(ctx: Context, pipeline: Pipeline): return random.choice(responses) return choice_response_handler diff --git a/dff/utils/testing/common.py b/dff/utils/testing/common.py index 5994663b7..b54bce9fe 100644 --- a/dff/utils/testing/common.py +++ b/dff/utils/testing/common.py @@ -4,7 +4,7 @@ This module contains several functions which are used to run demonstrations in tutorials. """ from os import getenv -from typing import Callable, Tuple, Any, Optional +from typing import Callable, Tuple, Optional from uuid import uuid4 from dff.script import Context, Message @@ -31,9 +31,9 @@ def is_interactive_mode() -> bool: # pragma: no cover def check_happy_path( pipeline: Pipeline, - happy_path: Tuple[Tuple[Any, Any], ...], + happy_path: Tuple[Tuple[Message, Message], ...], # This optional argument is used for additional processing of candidate responses and reference responses - response_comparer: Callable[[Any, Any, Context], Optional[str]] = default_comparer, + response_comparer: Callable[[Message, Message, Context], Optional[str]] = default_comparer, printout_enable: bool = True, ): """ @@ -56,6 +56,14 @@ def check_happy_path( if printout_enable: print(f"(user) >>> {repr(request)}") print(f" (bot) <<< {repr(candidate_response)}") + if candidate_response is None: + raise Exception( + f"\n\npipeline = {pipeline.info_dict}\n\n" + f"ctx = {ctx}\n\n" + f"step_id = {step_id}\n" + f"request = {repr(request)}\n" + "Candidate response is None." + ) parsed_response_with_deviation = response_comparer(candidate_response, reference_response, ctx) if parsed_response_with_deviation is not None: raise Exception( diff --git a/docs/source/user_guides/basic_conceptions.rst b/docs/source/user_guides/basic_conceptions.rst index 5344143e0..bdcab9c61 100644 --- a/docs/source/user_guides/basic_conceptions.rst +++ b/docs/source/user_guides/basic_conceptions.rst @@ -204,10 +204,9 @@ For instance, if a user wants to know a schedule, you may need to access a datab import requests ... - def use_api_processing(ctx: Context, _: Pipeline, *args, **kwargs) -> Context: + def use_api_processing(ctx: Context, _: Pipeline): # save to the context field for custom info ctx.misc["api_call_results"] = requests.get("http://schedule.api/day1").json() - return ctx ... node = { RESPONSE: ... @@ -239,7 +238,7 @@ The latter allows you to customize the response based on the specific scenario a .. 
code-block:: python - def sample_response(ctx: Context, _: Pipeline, *args, **kwargs) -> Message: + def sample_response(ctx: Context, _: Pipeline) -> Message: if ctx.misc["user"] == 'vegan': return Message(text="Here is a list of vegan cafes.") return Message(text="Here is a list of cafes.") @@ -259,7 +258,7 @@ This ensures a smoother user experience even when the bot encounters unexpected .. code-block:: python - def fallback_response(ctx: Context, _: Pipeline, *args, **kwargs) -> Message: + def fallback_response(ctx: Context, _: Pipeline) -> Message: """ Generate a special fallback response depending on the situation. """ @@ -352,7 +351,7 @@ that you may have in your project, using Python docstrings. .. code-block:: python - def fav_kitchen_response(ctx: Context, _: Pipeline, *args, **kwargs) -> Message: + def fav_kitchen_response(ctx: Context, _: Pipeline) -> Message: """ This function returns a user-targeted response depending on the value of the 'kitchen preference' slot. diff --git a/docs/source/user_guides/context_guide.rst b/docs/source/user_guides/context_guide.rst index 4f3f2dbc7..36117ecce 100644 --- a/docs/source/user_guides/context_guide.rst +++ b/docs/source/user_guides/context_guide.rst @@ -34,9 +34,7 @@ Let's consider some of the built-in callback instances to see how the context ca pattern = re.compile("[a-zA-Z]+") - def regexp_condition_handler( - ctx: Context, pipeline: Pipeline, *args, **kwargs - ) -> bool: + def regexp_condition_handler(ctx: Context, pipeline: Pipeline) -> bool: # retrieve the current request request = ctx.last_request if request.text is None: diff --git a/tests/pipeline/test_parallel_processing.py b/tests/pipeline/test_parallel_processing.py new file mode 100644 index 000000000..2e2e5beb3 --- /dev/null +++ b/tests/pipeline/test_parallel_processing.py @@ -0,0 +1,43 @@ +import asyncio + +import pytest + +from dff.script import Message, GLOBAL, RESPONSE, PRE_RESPONSE_PROCESSING, TRANSITIONS, conditions as cnd +from dff.pipeline import Pipeline + + +@pytest.mark.asyncio +async def test_parallel_processing(): + async def fast_processing(ctx, _): + processed_node = ctx.current_node + await asyncio.sleep(1) + processed_node.response = Message(text=f"fast: {processed_node.response.text}") + + async def slow_processing(ctx, _): + processed_node = ctx.current_node + await asyncio.sleep(2) + processed_node.response = Message(text=f"slow: {processed_node.response.text}") + + toy_script = { + GLOBAL: { + PRE_RESPONSE_PROCESSING: { + "first": slow_processing, + "second": fast_processing, + } + }, + "root": {"start": {TRANSITIONS: {"main": cnd.true()}}, "main": {RESPONSE: Message(text="text")}}, + } + + # test sequential processing + pipeline = Pipeline.from_script(toy_script, start_label=("root", "start"), parallelize_processing=False) + + ctx = await pipeline._run_pipeline(Message(), 0) + + assert ctx.last_response.text == "fast: slow: text" + + # test parallel processing + pipeline = Pipeline.from_script(toy_script, start_label=("root", "start"), parallelize_processing=True) + + ctx = await pipeline._run_pipeline(Message(), 0) + + assert ctx.last_response.text == "slow: fast: text" diff --git a/tests/pipeline/test_update_ctx_misc.py b/tests/pipeline/test_update_ctx_misc.py new file mode 100644 index 000000000..e1d5dd046 --- /dev/null +++ b/tests/pipeline/test_update_ctx_misc.py @@ -0,0 +1,34 @@ +import pytest + +from dff.pipeline import Pipeline +from dff.script import Message, RESPONSE, TRANSITIONS + + +@pytest.mark.asyncio +async def test_update_ctx_misc(): + 
def condition(ctx, _): + return ctx.misc["condition"] + + toy_script = { + "root": { + "start": {TRANSITIONS: {"success": condition}}, + "success": {RESPONSE: Message(text="success"), TRANSITIONS: {"success": condition}}, + "failure": { + RESPONSE: Message(text="failure"), + }, + } + } + + pipeline = Pipeline.from_script(toy_script, ("root", "start"), ("root", "failure")) + + ctx = await pipeline._run_pipeline(Message(), 0, update_ctx_misc={"condition": True}) + + assert ctx.last_response.text == "success" + + ctx = await pipeline._run_pipeline(Message(), 0) + + assert ctx.last_response.text == "success" + + ctx = await pipeline._run_pipeline(Message(), 0, update_ctx_misc={"condition": False}) + + assert ctx.last_response.text == "failure" diff --git a/tests/script/conditions/test_conditions.py b/tests/script/conditions/test_conditions.py index caec048f1..d8087d0e2 100644 --- a/tests/script/conditions/test_conditions.py +++ b/tests/script/conditions/test_conditions.py @@ -51,7 +51,7 @@ def test_conditions(): except TypeError: pass - def failed_cond_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def failed_cond_func(ctx: Context, pipeline: Pipeline) -> bool: raise ValueError("Failed cnd") assert not cnd.any([failed_cond_func])(ctx, pipeline) diff --git a/tests/script/core/test_actor.py b/tests/script/core/test_actor.py index 5dd14e27a..265ff9f08 100644 --- a/tests/script/core/test_actor.py +++ b/tests/script/core/test_actor.py @@ -1,4 +1,5 @@ # %% +import pytest from dff.pipeline import Pipeline from dff.script import ( TRANSITIONS, @@ -34,21 +35,22 @@ def negative_test(samples, custom_class): raise Exception(f"sample={sample} can not be passed") -def std_func(ctx, actor, *args, **kwargs): +def std_func(ctx, pipeline): pass -def fake_label(ctx: Context, actor, *args, **kwargs): +def fake_label(ctx: Context, pipeline): if not ctx.validation: return ("123", "123", 0) return ("flow", "node1", 1) -def raised_response(ctx: Context, actor, *args, **kwargs): +def raised_response(ctx: Context, pipeline): raise Exception("") -def test_actor(): +@pytest.mark.asyncio +async def test_actor(): try: # fail of start label Pipeline.from_script({"flow": {"node1": {}}}, start_label=("flow1", "node1")) @@ -80,7 +82,7 @@ def test_actor(): start_label=("flow", "node1"), ) ctx = Context() - pipeline.actor(pipeline, ctx) + await pipeline.actor(pipeline, ctx) raise Exception("can not be passed: fail of response returned Callable") except ValueError: pass @@ -99,14 +101,14 @@ def test_actor(): {"flow": {"node1": {TRANSITIONS: {"node1": true()}}}}, start_label=("flow", "node1") ) ctx = Context() - pipeline.actor(pipeline, ctx) + await pipeline.actor(pipeline, ctx) # fake label stability pipeline = Pipeline.from_script( {"flow": {"node1": {TRANSITIONS: {fake_label: true()}}}}, start_label=("flow", "node1") ) ctx = Context() - pipeline.actor(pipeline, ctx) + await pipeline.actor(pipeline, ctx) limit_errors = {} @@ -115,7 +117,7 @@ def test_actor(): def check_call_limit(limit: int = 1, default_value=None, label=""): counter = 0 - def call_limit_handler(ctx: Context, actor, *args, **kwargs): + def call_limit_handler(ctx: Context, pipeline): nonlocal counter counter += 1 if counter > limit: @@ -128,7 +130,8 @@ def call_limit_handler(ctx: Context, actor, *args, **kwargs): return call_limit_handler -def test_call_limit(): +@pytest.mark.asyncio +async def test_call_limit(): script = { GLOBAL: { TRANSITIONS: { @@ -209,15 +212,9 @@ def test_call_limit(): }, } # script = {"flow": {"node1": {TRANSITIONS: 
{"node1": true()}}}} - ctx = Context() pipeline = Pipeline.from_script(script=script, start_label=("flow1", "node1"), validation_stage=False) for i in range(4): - ctx.add_request(Message(text="req1")) - ctx = pipeline.actor(pipeline, ctx) + await pipeline._run_pipeline(Message(text="req1"), 0) if limit_errors: error_msg = repr(limit_errors) raise Exception(error_msg) - - -if __name__ == "__main__": - test_call_limit() diff --git a/tests/script/core/test_context.py b/tests/script/core/test_context.py index 757839176..f86182afd 100644 --- a/tests/script/core/test_context.py +++ b/tests/script/core/test_context.py @@ -1,7 +1,7 @@ # %% import random -from dff.script import Context, Node, Message +from dff.script import Context, Message def shuffle_dict_keys(dictionary: dict) -> dict: @@ -51,7 +51,6 @@ def test_context(): } assert ctx.misc == {"1001": "11111"} assert ctx.current_node is None - ctx.overwrite_current_node_in_processing(Node(**{"response": Message(text="text")})) ctx.model_dump_json() try: diff --git a/tests/script/core/test_normalization.py b/tests/script/core/test_normalization.py index 6bd443bf6..9fc1c1244 100644 --- a/tests/script/core/test_normalization.py +++ b/tests/script/core/test_normalization.py @@ -18,15 +18,10 @@ from dff.script.labels import repeat from dff.script.conditions import true -from dff.script.core.normalization import ( - normalize_condition, - normalize_label, - normalize_processing, - normalize_response, -) +from dff.script.core.normalization import normalize_condition, normalize_label, normalize_response -def std_func(ctx, actor, *args, **kwargs): +def std_func(ctx, pipeline): pass @@ -41,10 +36,10 @@ def create_env() -> Tuple[Context, Pipeline]: def test_normalize_label(): ctx, actor = create_env() - def true_label_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def true_label_func(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: return ("flow", "node1", 1) - def false_label_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def false_label_func(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: return ("flow", "node2", 1) n_f = normalize_label(true_label_func) @@ -62,10 +57,10 @@ def false_label_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> NodeL def test_normalize_condition(): ctx, actor = create_env() - def true_condition_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def true_condition_func(ctx: Context, pipeline: Pipeline) -> bool: return True - def false_condition_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: + def false_condition_func(ctx: Context, pipeline: Pipeline) -> bool: raise Exception("False condition") n_f = normalize_condition(true_condition_func) @@ -90,27 +85,6 @@ def test_normalize_response(): assert callable(normalize_response(Message(text="text"))) -def test_normalize_processing(): - ctx, actor = create_env() - - def true_processing_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - return ctx - - def false_processing_func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - raise Exception("False processing") - - n_f = normalize_processing({1: true_processing_func}) - assert callable(n_f) - assert isinstance(n_f(ctx, actor), Context) - n_f = normalize_processing({1: false_processing_func}) - assert isinstance(n_f(ctx, actor), Context) - - # TODO: Add full check for functions - assert callable(normalize_processing({})) - assert callable(normalize_processing({1: std_func})) - assert 
callable(normalize_processing({1: std_func, 2: std_func})) - - def test_normalize_keywords(): node_template = { TRANSITIONS: {"node": std_func}, diff --git a/tests/script/core/test_script.py b/tests/script/core/test_script.py index 3d2f1840d..87d82f863 100644 --- a/tests/script/core/test_script.py +++ b/tests/script/core/test_script.py @@ -34,7 +34,7 @@ def negative_test(samples, custom_class): raise Exception(f"sample={sample} can not be passed") -def std_func(ctx, actor, *args, **kwargs): +def std_func(ctx, pipeline): pass diff --git a/tests/tutorials/test_format.py b/tests/tutorials/test_format.py index 772a059bd..d63b19c3f 100644 --- a/tests/tutorials/test_format.py +++ b/tests/tutorials/test_format.py @@ -14,7 +14,7 @@ ] docstring_start_pattern = re.compile(r'# %% \[markdown\]\n"""\n#(?: .*:)? \d+\. .*\n(?:\n[\S\s]*)?"""(?: # .*)?\n') -comment_start_pattern = re.compile(r'# %% \[markdown\]\n# #(?: .*:)? \d+\. .*\n#(?:\n# [\S\s]*)?') +comment_start_pattern = re.compile(r"# %% \[markdown\]\n# #(?: .*:)? \d+\. .*\n#(?:\n# [\S\s]*)?") def regexp_format_checker(dff_tutorial_py_file: pathlib.Path): diff --git a/tutorials/messengers/telegram/5_conditions_with_media.py b/tutorials/messengers/telegram/5_conditions_with_media.py index 7f80ae3c6..144ef9c32 100644 --- a/tutorials/messengers/telegram/5_conditions_with_media.py +++ b/tutorials/messengers/telegram/5_conditions_with_media.py @@ -165,24 +165,24 @@ def extract_data(ctx: Context, _: Pipeline): # A function to extract data with message = ctx.last_request if message is None: - return ctx + return update = getattr(message, "update", None) if update is None: - return ctx + return if not isinstance(update, Message): - return ctx + return if ( # check attachments in update properties not update.photo and not (update.document and update.document.mime_type == "image/jpeg") ): - return ctx + return photo = update.document or update.photo[-1] file = interface.messenger.get_file(photo.file_id) result = interface.messenger.download_file(file.file_path) with open("photo.jpg", "wb+") as new_file: new_file.write(result) - return ctx + return # %% diff --git a/tutorials/pipeline/3_pipeline_dict_with_services_basic.py b/tutorials/pipeline/3_pipeline_dict_with_services_basic.py index dbe7fb1e5..131b62523 100644 --- a/tutorials/pipeline/3_pipeline_dict_with_services_basic.py +++ b/tutorials/pipeline/3_pipeline_dict_with_services_basic.py @@ -39,7 +39,7 @@ On pipeline execution services from `services` list are run without difference between pre- and postprocessors. -Actor constant "ACTOR" should also be present among services. +Actor constant "ACTOR" is required to be passed as one of the services. ServiceBuilder object can be defined either with callable (see tutorial 2) or with dict / object. It should contain `handler` - a ServiceBuilder object. @@ -48,7 +48,7 @@ for most cases `run` method should be used. It starts pipeline asynchronously and connects to provided messenger interface. -Here pipeline contains 4 services, +Here, the pipeline contains 4 services, defined in 4 different ways with different signatures. """ diff --git a/tutorials/pipeline/5_asynchronous_groups_and_services_basic.py b/tutorials/pipeline/5_asynchronous_groups_and_services_basic.py index 4d25fa7b9..77920ed15 100644 --- a/tutorials/pipeline/5_asynchronous_groups_and_services_basic.py +++ b/tutorials/pipeline/5_asynchronous_groups_and_services_basic.py @@ -37,7 +37,7 @@ each of them should sleep for 0.01 of a second. 
However, as the group is asynchronous, it is being executed for 0.01 of a second in total. -Service group `pipeline` can't be asynchronous because `actor` is synchronous. +Service group can be synchronous or asynchronous. """ diff --git a/tutorials/pipeline/5_asynchronous_groups_and_services_full.py b/tutorials/pipeline/5_asynchronous_groups_and_services_full.py index 137c02989..a3b7e5e80 100644 --- a/tutorials/pipeline/5_asynchronous_groups_and_services_full.py +++ b/tutorials/pipeline/5_asynchronous_groups_and_services_full.py @@ -52,7 +52,7 @@ the service becomes asynchronous, and if set, it is used instead. If service can not be asynchronous, but is marked asynchronous, an exception is thrown. -NB! ACTOR service is always synchronous. +ACTOR service is asynchronous. The timeout field only works for asynchronous services and service groups. If service execution takes more time than timeout, diff --git a/tutorials/script/core/2_conditions.py b/tutorials/script/core/2_conditions.py index d8838b9f4..fa6b1b8ec 100644 --- a/tutorials/script/core/2_conditions.py +++ b/tutorials/script/core/2_conditions.py @@ -36,7 +36,7 @@ See tutorial 1 of pipeline (pipeline/1_basics) to learn more about Actor. Condition functions have signature - def func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool + def func(ctx: Context, pipeline: Pipeline) -> bool Out of the box `dff.script.conditions` offers the following options for setting conditions: @@ -64,7 +64,7 @@ def func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool For example function ``` -def always_true_condition(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> bool: +def always_true_condition(ctx: Context, pipeline: Pipeline) -> bool: return True ``` always returns `True` and `always_true_condition` function @@ -75,7 +75,7 @@ def always_true_condition(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> # %% -def hi_lower_case_condition(ctx: Context, _: Pipeline, *args, **kwargs) -> bool: +def hi_lower_case_condition(ctx: Context, _: Pipeline) -> bool: request = ctx.last_request # Returns True if `hi` in both uppercase and lowercase # letters is contained in the user request. @@ -84,9 +84,7 @@ def hi_lower_case_condition(ctx: Context, _: Pipeline, *args, **kwargs) -> bool: return "hi" in request.text.lower() -def complex_user_answer_condition( - ctx: Context, _: Pipeline, *args, **kwargs -) -> bool: +def complex_user_answer_condition(ctx: Context, _: Pipeline) -> bool: request = ctx.last_request # The user request can be anything. if request is None or request.misc is None: @@ -96,9 +94,7 @@ def complex_user_answer_condition( def predetermined_condition(condition: bool): # Wrapper for internal condition function. - def internal_condition_function( - ctx: Context, _: Pipeline, *args, **kwargs - ) -> bool: + def internal_condition_function(ctx: Context, _: Pipeline) -> bool: # It always returns `condition`. return condition diff --git a/tutorials/script/core/3_responses.py b/tutorials/script/core/3_responses.py index 6d42b3b2b..9218d1ed7 100644 --- a/tutorials/script/core/3_responses.py +++ b/tutorials/script/core/3_responses.py @@ -34,7 +34,7 @@ * Callable objects. If the object is callable it must have a special signature: - func(ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Any + func(ctx: Context, pipeline: Pipeline) -> Message * *Message objects. If the object is *Message it will be returned by the agent as a response. 
@@ -45,9 +45,7 @@ # %% -def cannot_talk_about_topic_response( - ctx: Context, _: Pipeline, *args, **kwargs -) -> Message: +def cannot_talk_about_topic_response(ctx: Context, _: Pipeline) -> Message: request = ctx.last_request if request is None or request.text is None: topic = None @@ -63,7 +61,7 @@ def cannot_talk_about_topic_response( def upper_case_response(response: Message): # wrapper for internal response function - def func(_: Context, __: Pipeline, *args, **kwargs) -> Message: + def func(_: Context, __: Pipeline) -> Message: if response.text is not None: response.text = response.text.upper() return response @@ -71,9 +69,7 @@ def func(_: Context, __: Pipeline, *args, **kwargs) -> Message: return func -def fallback_trace_response( - ctx: Context, _: Pipeline, *args, **kwargs -) -> Message: +def fallback_trace_response(ctx: Context, _: Pipeline) -> Message: return Message( misc={ "previous_node": list(ctx.labels.values())[-2], diff --git a/tutorials/script/core/4_transitions.py b/tutorials/script/core/4_transitions.py index 1a1ece538..2617bdff9 100644 --- a/tutorials/script/core/4_transitions.py +++ b/tutorials/script/core/4_transitions.py @@ -41,14 +41,12 @@ # %% -def greeting_flow_n2_transition( - _: Context, __: Pipeline, *args, **kwargs -) -> NodeLabel3Type: +def greeting_flow_n2_transition(_: Context, __: Pipeline) -> NodeLabel3Type: return ("greeting_flow", "node2", 1.0) def high_priority_node_transition(flow_label, label): - def transition(_: Context, __: Pipeline, *args, **kwargs) -> NodeLabel3Type: + def transition(_: Context, __: Pipeline) -> NodeLabel3Type: return (flow_label, label, 2.0) return transition diff --git a/tutorials/script/core/6_context_serialization.py b/tutorials/script/core/6_context_serialization.py index bcf6cb7e2..855ab9d07 100644 --- a/tutorials/script/core/6_context_serialization.py +++ b/tutorials/script/core/6_context_serialization.py @@ -29,7 +29,7 @@ # %% -def response_handler(ctx: Context, _: Pipeline, *args, **kwargs) -> Message: +def response_handler(ctx: Context, _: Pipeline) -> Message: return Message(text=f"answer {len(ctx.requests)}") diff --git a/tutorials/script/core/7_pre_response_processing.py b/tutorials/script/core/7_pre_response_processing.py index 38269b6a6..3f7548a29 100644 --- a/tutorials/script/core/7_pre_response_processing.py +++ b/tutorials/script/core/7_pre_response_processing.py @@ -38,15 +38,11 @@ # %% def add_prefix(prefix): - def add_prefix_processing( - ctx: Context, _: Pipeline, *args, **kwargs - ) -> Context: + def add_prefix_processing(ctx: Context, _: Pipeline): processed_node = ctx.current_node processed_node.response = Message( text=f"{prefix}: {processed_node.response.text}" ) - ctx.overwrite_current_node_in_processing(processed_node) - return ctx return add_prefix_processing diff --git a/tutorials/script/core/8_misc.py b/tutorials/script/core/8_misc.py index ac5eb813d..55b582a2e 100644 --- a/tutorials/script/core/8_misc.py +++ b/tutorials/script/core/8_misc.py @@ -33,7 +33,7 @@ # %% -def custom_response(ctx: Context, _: Pipeline, *args, **kwargs) -> Message: +def custom_response(ctx: Context, _: Pipeline) -> Message: if ctx.validation: return Message() current_node = ctx.current_node diff --git a/tutorials/script/core/9_pre_transitions_processing.py b/tutorials/script/core/9_pre_transitions_processing.py index b875b4eb0..259b02e60 100644 --- a/tutorials/script/core/9_pre_transitions_processing.py +++ b/tutorials/script/core/9_pre_transitions_processing.py @@ -34,24 +34,17 @@ # %% -def 
save_previous_node_response_to_ctx_processing( - ctx: Context, _: Pipeline, *args, **kwargs -) -> Context: +def save_previous_node_response(ctx: Context, _: Pipeline): processed_node = ctx.current_node ctx.misc["previous_node_response"] = processed_node.response - return ctx -def get_previous_node_response_for_response_processing( - ctx: Context, _: Pipeline, *args, **kwargs -) -> Context: +def prepend_previous_node_response(ctx: Context, _: Pipeline): processed_node = ctx.current_node processed_node.response = Message( text=f"previous={ctx.misc['previous_node_response'].text}:" f" current={processed_node.response.text}" ) - ctx.overwrite_current_node_in_processing(processed_node) - return ctx # %% @@ -66,10 +59,10 @@ def get_previous_node_response_for_response_processing( }, GLOBAL: { PRE_RESPONSE_PROCESSING: { - "proc_name_1": get_previous_node_response_for_response_processing + "proc_name_1": prepend_previous_node_response }, PRE_TRANSITIONS_PROCESSING: { - "proc_name_1": save_previous_node_response_to_ctx_processing + "proc_name_1": save_previous_node_response }, TRANSITIONS: {lbl.forward(0.1): cnd.true()}, },
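
(Editor's note.) Taken together, the changes above replace the old return-the-context processing style with in-place mutation of `ctx.current_node`, and add the `update_ctx_misc` argument for seeding `ctx.misc` before a turn is processed. Below is a minimal end-to-end sketch of the new conventions, mirroring `tests/pipeline/test_update_ctx_misc.py` and `tests/pipeline/test_parallel_processing.py` from this diff; the flow, node, and key names are hypothetical.

```python
import asyncio

from dff.pipeline import Pipeline
from dff.script import (
    GLOBAL,
    PRE_RESPONSE_PROCESSING,
    RESPONSE,
    TRANSITIONS,
    Context,
    Message,
)


def add_prefix(ctx: Context, _: Pipeline):
    # New processing style: mutate the current node in place;
    # no overwrite_current_node_in_processing() and no `return ctx` needed.
    node = ctx.current_node
    node.response = Message(text=f"bot: {node.response.text}")


def flag_is_set(ctx: Context, _: Pipeline) -> bool:
    # Conditions now take exactly (ctx, pipeline).
    return ctx.misc.get("flag", False)


script = {
    GLOBAL: {PRE_RESPONSE_PROCESSING: {"prefix": add_prefix}},
    "flow": {
        "start": {TRANSITIONS: {"greet": flag_is_set}},
        "greet": {RESPONSE: Message(text="hello")},
        "fallback": {RESPONSE: Message(text="fallback")},
    },
}

pipeline = Pipeline.from_script(script, ("flow", "start"), ("flow", "fallback"))

# update_ctx_misc seeds ctx.misc before the request is handled,
# so the condition above already sees the flag on the very first turn.
ctx = asyncio.run(pipeline._run_pipeline(Message(), 0, update_ctx_misc={"flag": True}))
assert ctx.last_response.text == "bot: hello"
```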