diff --git a/dff/pipeline/pipeline/actor.py b/dff/pipeline/pipeline/actor.py index ec140ee57..4e653b1af 100644 --- a/dff/pipeline/pipeline/actor.py +++ b/dff/pipeline/pipeline/actor.py @@ -144,9 +144,9 @@ async def __call__( await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING, *args, **kwargs) # create response - ctx.framework_states["actor"]["response"] = ctx.framework_states["actor"][ - "pre_response_processed_node" - ].run_response(ctx, pipeline, *args, **kwargs) + ctx.framework_states["actor"]["response"] = await self.run_response( + ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline, *args, **kwargs + ) await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE, *args, **kwargs) ctx.add_response(ctx.framework_states["actor"]["response"]) @@ -259,9 +259,26 @@ def _overwrite_node( overwritten_node.transitions = current_node.transitions return overwritten_node - async def _run_processing_parallel(self, processing: dict, ctx: Context, pipeline: Pipeline) -> Context: + async def run_response( + self, + response: Optional[Union[Message, Callable[..., Message]]], + ctx: Context, + pipeline: Pipeline, + *args, + **kwargs, + ) -> Context: + """ + Executes the normalized response. + See details in the :py:func:`~normalize_response` function of `normalization.py`. + """ + response = normalize_response(response) + return await wrap_sync_function_in_async(response, ctx, pipeline, *args, **kwargs) + + async def _run_processing_parallel( + self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs + ) -> Context: results = await asyncio.gather( - *[wrap_sync_function_in_async(func, ctx, pipeline) for func in processing.values()], + *[wrap_sync_function_in_async(func, ctx, pipeline, *args, **kwargs) for func in processing.values()], return_exceptions=True, ) for exc, (processing_name, processing_func) in zip(results, processing.items()): @@ -272,10 +289,12 @@ async def _run_processing_parallel(self, processing: dict, ctx: Context, pipelin ) return ctx - async def _run_processing_sequential(self, processing: dict, ctx: Context, pipeline: Pipeline) -> Context: + async def _run_processing_sequential( + self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs + ) -> Context: for processing_name, processing_func in processing.items(): try: - ctx = await wrap_sync_function_in_async(processing_func, ctx, pipeline) + ctx = await wrap_sync_function_in_async(processing_func, ctx, pipeline, *args, **kwargs) except Exception as exc: logger.error( f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}", @@ -393,7 +412,7 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True): # validate responsing response_func = normalize_response(node.response) try: - response_result = response_func(ctx, pipeline) + response_result = asyncio.run(wrap_sync_function_in_async(response_func, ctx, pipeline)) if not isinstance(response_result, Message): msg = ( "Expected type of response_result is `Message`.\n" diff --git a/dff/script/core/normalization.py b/dff/script/core/normalization.py index ecabffa7c..63739d17e 100644 --- a/dff/script/core/normalization.py +++ b/dff/script/core/normalization.py @@ -7,7 +7,7 @@ """ import logging -from typing import Union, Callable, Any, Dict, Optional, ForwardRef +from typing import Union, Callable, Optional, ForwardRef from .keywords import Keywords from .context import Context @@ -100,7 +100,7 @@ def normalize_response(response: Optional[Union[Message, Callable[..., Message]] else: raise TypeError(type(response)) - def response_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs): + async def response_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs): return result return response_handler diff --git a/dff/script/core/script.py b/dff/script/core/script.py index 323a66d1e..0c6744278 100644 --- a/dff/script/core/script.py +++ b/dff/script/core/script.py @@ -15,7 +15,7 @@ from .types import LabelType, NodeLabelType, ConditionType, NodeLabel3Type from .message import Message from .keywords import Keywords -from .normalization import normalize_response, normalize_condition, normalize_label, validate_call +from .normalization import normalize_condition, normalize_label, validate_call from typing import ForwardRef logger = logging.getLogger(__name__) @@ -53,14 +53,6 @@ def normalize_transitions( } return transitions - def run_response(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context: - """ - Executes the normalized response. - See details in the :py:func:`~normalize_response` function of `normalization.py`. - """ - response = normalize_response(self.response) - return response(ctx, pipeline, *args, **kwargs) - class Script(BaseModel, extra="forbid"): """