Skip to content

Commit

Permalink
move run_response from Node object to Actor object; execute response …
Browse files Browse the repository at this point in the history
…functions asynchronously
  • Loading branch information
ruthenian8 committed Oct 13, 2023
1 parent a110e62 commit bd67607
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 19 deletions.
35 changes: 27 additions & 8 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ async def __call__(
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING, *args, **kwargs)

# create response
ctx.framework_states["actor"]["response"] = ctx.framework_states["actor"][
"pre_response_processed_node"
].run_response(ctx, pipeline, *args, **kwargs)
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline, *args, **kwargs
)
await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE, *args, **kwargs)
ctx.add_response(ctx.framework_states["actor"]["response"])

Expand Down Expand Up @@ -259,9 +259,26 @@ def _overwrite_node(
overwritten_node.transitions = current_node.transitions
return overwritten_node

async def _run_processing_parallel(self, processing: dict, ctx: Context, pipeline: Pipeline) -> Context:
async def run_response(
self,
response: Optional[Union[Message, Callable[..., Message]]],
ctx: Context,
pipeline: Pipeline,
*args,
**kwargs,
) -> Context:
"""
Executes the normalized response.
See details in the :py:func:`~normalize_response` function of `normalization.py`.
"""
response = normalize_response(response)
return await wrap_sync_function_in_async(response, ctx, pipeline, *args, **kwargs)

async def _run_processing_parallel(
self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs
) -> Context:
results = await asyncio.gather(
*[wrap_sync_function_in_async(func, ctx, pipeline) for func in processing.values()],
*[wrap_sync_function_in_async(func, ctx, pipeline, *args, **kwargs) for func in processing.values()],
return_exceptions=True,
)
for exc, (processing_name, processing_func) in zip(results, processing.items()):
Expand All @@ -272,10 +289,12 @@ async def _run_processing_parallel(self, processing: dict, ctx: Context, pipelin
)
return ctx

async def _run_processing_sequential(self, processing: dict, ctx: Context, pipeline: Pipeline) -> Context:
async def _run_processing_sequential(
self, processing: dict, ctx: Context, pipeline: Pipeline, *args, **kwargs
) -> Context:
for processing_name, processing_func in processing.items():
try:
ctx = await wrap_sync_function_in_async(processing_func, ctx, pipeline)
ctx = await wrap_sync_function_in_async(processing_func, ctx, pipeline, *args, **kwargs)
except Exception as exc:
logger.error(
f"Exception {exc} for processing_name={processing_name} and processing_func={processing_func}",
Expand Down Expand Up @@ -393,7 +412,7 @@ def validate_script(self, pipeline: Pipeline, verbose: bool = True):
# validate responsing
response_func = normalize_response(node.response)
try:
response_result = response_func(ctx, pipeline)
response_result = asyncio.run(wrap_sync_function_in_async(response_func, ctx, pipeline))
if not isinstance(response_result, Message):
msg = (
"Expected type of response_result is `Message`.\n"
Expand Down
4 changes: 2 additions & 2 deletions dff/script/core/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import logging

from typing import Union, Callable, Any, Dict, Optional, ForwardRef
from typing import Union, Callable, Optional, ForwardRef

from .keywords import Keywords
from .context import Context
Expand Down Expand Up @@ -100,7 +100,7 @@ def normalize_response(response: Optional[Union[Message, Callable[..., Message]]
else:
raise TypeError(type(response))

def response_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs):
async def response_handler(ctx: Context, pipeline: Pipeline, *args, **kwargs):
return result

return response_handler
10 changes: 1 addition & 9 deletions dff/script/core/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .types import LabelType, NodeLabelType, ConditionType, NodeLabel3Type
from .message import Message
from .keywords import Keywords
from .normalization import normalize_response, normalize_condition, normalize_label, validate_call
from .normalization import normalize_condition, normalize_label, validate_call
from typing import ForwardRef

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,14 +53,6 @@ def normalize_transitions(
}
return transitions

def run_response(self, ctx: Context, pipeline: Pipeline, *args, **kwargs) -> Context:
"""
Executes the normalized response.
See details in the :py:func:`~normalize_response` function of `normalization.py`.
"""
response = normalize_response(self.response)
return response(ctx, pipeline, *args, **kwargs)


class Script(BaseModel, extra="forbid"):
"""
Expand Down

0 comments on commit bd67607

Please sign in to comment.