From 25a2eba81c389583209334dd03931bd015b444fb Mon Sep 17 00:00:00 2001 From: pseusys Date: Fri, 15 Dec 2023 10:34:23 +0100 Subject: [PATCH] typing + linting fixed --- dff/script/slots/conditions.py | 10 ++- dff/script/slots/forms.py | 14 +++-- dff/script/slots/handlers.py | 8 ++- dff/script/slots/processing.py | 16 ++--- dff/script/slots/response.py | 6 +- dff/script/slots/types.py | 51 ++++++++-------- tutorials/slots/1_basic_example.py | 78 ++++++++++++++++++------ tutorials/slots/2_form_example.py | 87 +++++++++++++++++++++------ tutorials/slots/3_handlers_example.py | 71 +++++++++++++++++----- 9 files changed, 239 insertions(+), 102 deletions(-) diff --git a/dff/script/slots/conditions.py b/dff/script/slots/conditions.py index e9342c196..4b254ad7d 100644 --- a/dff/script/slots/conditions.py +++ b/dff/script/slots/conditions.py @@ -21,11 +21,9 @@ def check_slot_state(ctx: Context, _: Pipeline) -> bool: return check_slot_state -def is_set_all(paths: list): - cond = all_condition(*[slot_extracted_condition(path) for path in paths]) - return cond +def is_set_all(paths: list) -> StartConditionCheckerFunction: + return all_condition(*[slot_extracted_condition(path) for path in paths]) -def is_set_any(paths: list): - cond = any_condition(*[slot_extracted_condition(path) for path in paths]) - return cond +def is_set_any(paths: list) -> StartConditionCheckerFunction: + return any_condition(*[slot_extracted_condition(path) for path in paths]) diff --git a/dff/script/slots/forms.py b/dff/script/slots/forms.py index 31fcfe4d2..15a73e7c6 100644 --- a/dff/script/slots/forms.py +++ b/dff/script/slots/forms.py @@ -64,7 +64,9 @@ class FormPolicy(BaseModel): allowed_repeats: int = Field(default=0, gt=-1) node_cache: Dict[NodeLabel2Type, int] = Field(default_factory=Counter) - def __init__(self, name: str, mapping: Dict[str, List[NodeLabel2Type]], *, allowed_repeats: int = 0, **data): + def __init__( + self, name: str, mapping: Dict[str, List[NodeLabel2Type]], *, allowed_repeats: int = 0, **data + ) -> None: """ Create a new form. @@ -121,7 +123,7 @@ def to_next_label_inner(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type: return to_next_label_inner @validate_call - def has_state(self, state: FormState): + def has_state(self, state: FormState) -> Callable[[Context, Pipeline], bool]: """ This method produces a dff.core.engine condition that yields `True` if the state of the form equals the passed :class:`~.FormState` or `False` otherwise. @@ -137,7 +139,7 @@ def is_active_inner(ctx: Context, pipeline: Pipeline) -> bool: return is_active_inner @validate_call - def update_state(self, state: Optional[FormState] = None): + def update_state(self, state: Optional[FormState] = None) -> Callable[[Context, Pipeline], None]: """ This method updates the form state that is stored in the context. @@ -148,7 +150,7 @@ def update_state(self, state: Optional[FormState] = None): """ - def update_inner(ctx: Context, pipeline: Pipeline): + def update_inner(ctx: Context, pipeline: Pipeline) -> None: ctx.framework_states.setdefault(FORM_STORAGE_KEY, {}) if state: @@ -165,8 +167,8 @@ def update_inner(ctx: Context, pipeline: Pipeline): return update_inner - def get_values(self): - def get_values_inner(ctx: Context, pipeline: Pipeline): + def get_values(self) -> Callable[[Context, Pipeline], List[Dict[str, Union[str, None]]]]: + def get_values_inner(ctx: Context, pipeline: Pipeline) -> List[Dict[str, Union[str, None]]]: slots = list(self.mapping.keys()) return get_values(ctx, pipeline, slots) diff --git a/dff/script/slots/handlers.py b/dff/script/slots/handlers.py index 912c33ecf..f3851a7b2 100644 --- a/dff/script/slots/handlers.py +++ b/dff/script/slots/handlers.py @@ -4,7 +4,7 @@ This module is for general functions that can be used in processing, conditions, or responses. """ import logging -from typing import Dict, Optional, List +from typing import Any, Dict, Optional, List, Union from dff.script import Context from dff.pipeline import Pipeline @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -def extract(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None) -> list: +def extract(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None) -> List[Optional[Any]]: """ Extract the specified slots and return the received values as a list. If the value of a particular slot cannot be extracted, None is included instead. @@ -56,7 +56,9 @@ def extract(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None) return results -def get_values(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None) -> list: +def get_values( + ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None +) -> List[Dict[str, Union[str, None]]]: """ Get values of the specified slots, assuming that they have been extracted beforehand. If slot argument is omitted, values of all slots will be returned. diff --git a/dff/script/slots/processing.py b/dff/script/slots/processing.py index 0b056308e..6a765ada5 100644 --- a/dff/script/slots/processing.py +++ b/dff/script/slots/processing.py @@ -4,11 +4,11 @@ This module encapsulates operations that can be done to slots during the processing stage. """ import logging -from typing import Optional, List, Callable +from typing import Awaitable, Optional, List, Callable from pydantic import validate_call -from dff.script import Context, Message +from dff.script import Context from dff.pipeline import Pipeline from .handlers import get_filled_template, extract as extract_handler, unset as unset_handler @@ -17,7 +17,7 @@ @validate_call -def extract(slots: Optional[List[str]]) -> Callable: +def extract(slots: Optional[List[str]]) -> Callable[[Context, Pipeline], Awaitable[None]]: """ Extract slots from a specified list. @@ -25,7 +25,7 @@ def extract(slots: Optional[List[str]]) -> Callable: Names of slots inside groups should be prefixed with group names, separated by '/': profile/username. """ - async def extract_inner(ctx: Context, pipeline: Pipeline): + async def extract_inner(ctx: Context, pipeline: Pipeline) -> None: _ = extract_handler(ctx, pipeline, slots) return @@ -33,8 +33,8 @@ async def extract_inner(ctx: Context, pipeline: Pipeline): @validate_call -def unset(slots: Optional[List[str]] = None): - def unset_inner(ctx: Context, pipeline: Pipeline): +def unset(slots: Optional[List[str]] = None) -> Callable[[Context, Pipeline], None]: + def unset_inner(ctx: Context, pipeline: Pipeline) -> None: unset_handler(ctx, pipeline, slots) return @@ -42,7 +42,7 @@ def unset_inner(ctx: Context, pipeline: Pipeline): @validate_call -def fill_template(slots: Optional[List[str]] = None): +def fill_template(slots: Optional[List[str]] = None) -> Callable[[Context, Pipeline], None]: """ Fill the response template in the current node. Response should be an instance of :py:class:`~.Message`. @@ -51,7 +51,7 @@ def fill_template(slots: Optional[List[str]] = None): :param slots: Slot names to use. If this parameter is omitted, all slots will be used. """ - def fill_inner(ctx: Context, pipeline: Pipeline): + def fill_inner(ctx: Context, pipeline: Pipeline) -> None: # get current node response response = ctx.current_node.response if callable(response): diff --git a/dff/script/slots/response.py b/dff/script/slots/response.py index c15534bee..8210ec0e5 100644 --- a/dff/script/slots/response.py +++ b/dff/script/slots/response.py @@ -4,7 +4,7 @@ This module is for functions that should be executed at the response stage. They produce the response that will be ultimately given to the user. """ -from typing import Optional, List +from typing import Callable, Optional, List from pydantic import validate_call @@ -15,7 +15,7 @@ @validate_call -def fill_template(template: Message, slots: Optional[List[str]] = None): +def fill_template(template: Message, slots: Optional[List[str]] = None) -> Callable[[Context, Pipeline], Message]: """ Fill a template with slot values. Response should be an instance of :py:class:`~.Message` class. @@ -24,7 +24,7 @@ def fill_template(template: Message, slots: Optional[List[str]] = None): :param slots: Slot names to use. If this parameter is omitted, all slots will be used. """ - def fill_inner(ctx: Context, pipeline: Pipeline): + def fill_inner(ctx: Context, pipeline: Pipeline) -> Message: new_template = template.model_copy() new_text = get_filled_template(template.text, ctx, pipeline, slots) new_template.text = new_text diff --git a/dff/script/slots/types.py b/dff/script/slots/types.py index e1a7d6836..fcfc4dabb 100644 --- a/dff/script/slots/types.py +++ b/dff/script/slots/types.py @@ -26,19 +26,19 @@ class BaseSlot(BaseModel, ABC, arbitrary_types_allowed=True): children: Optional[Dict[str, "BaseSlot"]] @field_validator("name", mode="before") - def validate_name(cls, name: str): + def validate_name(cls, name: str) -> str: if "/" in name: raise ValueError("Character `/` cannot be used in slot names.") return name - def __init__(self, name: str, **data): + def __init__(self, name: str, **data) -> None: super().__init__(name=name, **data) - def __deepcopy__(self, *args, **kwargs): + def __deepcopy__(self) -> "BaseSlot": return copy(self) - def __eq__(self, other: "BaseSlot"): - return self.dict(exclude={"name"}) == other.dict(exclude={"name"}) + def __eq__(self, other: "BaseSlot") -> bool: + return self.model_dump(exclude={"name"}) == other.model_dump(exclude={"name"}) def has_children(self) -> bool: return self.children is not None and len(self.children) > 0 @@ -79,7 +79,7 @@ class _GroupSlot(BaseSlot): children: Dict[str, BaseSlot] = Field(default_factory=dict) @field_validator("children", mode="before") - def validate_children(cls, children, values: dict): + def validate_children(cls, children: Iterable, values: Dict[str, Any]) -> Dict[str, BaseSlot]: if not isinstance(children, dict) and isinstance(children, Iterable): children = {child.name: child for child in children} if len(children) == 0: @@ -87,8 +87,8 @@ def validate_children(cls, children, values: dict): raise ValueError(f"Error in slot {name}: group slot should have at least one child or more.") return children - def is_set(self): - def is_set_inner(ctx: Context, pipeline: Pipeline): + def is_set(self) -> Callable[[Context, Pipeline], bool]: + def is_set_inner(ctx: Context, pipeline: Pipeline) -> bool: return all([child.is_set()(ctx, pipeline) for child in self.children.values()]) return is_set_inner @@ -105,14 +105,14 @@ def get_inner(ctx: Context, pipeline: Pipeline) -> Dict[str, Union[str, None]]: return get_inner - def unset_value(self): - def unset_inner(ctx: Context, pipeline: Pipeline): + def unset_value(self) -> Callable[[Context, Pipeline], None]: + def unset_inner(ctx: Context, pipeline: Pipeline) -> None: for child in self.children.values(): child.unset_value()(ctx, pipeline) return unset_inner - def fill_template(self, template: str) -> Callable: + def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]: def fill_inner(ctx: Context, pipeline: Pipeline) -> str: new_template = template for _, child in self.children.items(): @@ -122,7 +122,7 @@ def fill_inner(ctx: Context, pipeline: Pipeline) -> str: return fill_inner - def extract_value(self, ctx: Context, pipeline: Pipeline): + def extract_value(self, ctx: Context, pipeline: Pipeline) -> Any: for child in self.children.values(): _ = child.extract_value(ctx, pipeline) return self.get_value()(ctx, pipeline) @@ -133,6 +133,7 @@ class RootSlot(_GroupSlot): """ Root slot is a universally unique slot group that automatically registers all the other slots and makes them globally available. + """ @staticmethod @@ -148,7 +149,7 @@ def flatten_slot_tree(node: BaseSlot) -> Tuple[Dict[str, BaseSlot], Dict[str, Ba remove_nodes.update(child_remove_nodes) return add_nodes, remove_nodes - def add_slots(self, slots: Union[BaseSlot, Iterable]): + def add_slots(self, slots: Union[BaseSlot, Iterable]) -> None: if isinstance(slots, BaseSlot): add_nodes, _ = self.flatten_slot_tree(slots) self.children.update(add_nodes) @@ -161,7 +162,7 @@ def add_slots(self, slots: Union[BaseSlot, Iterable]): class ChildSlot(BaseSlot): - def __init__(self, *, name, **kwargs): + def __init__(self, *, name, **kwargs) -> None: super().__init__(name=name, **kwargs) root_slot.add_slots(self) @@ -188,8 +189,8 @@ class ValueSlot(ChildSlot): children: None = Field(None) value: Any = None - def is_set(self): - def is_set_inner(ctx: Context, _: Pipeline): + def is_set(self) -> Callable[[Context, Pipeline], bool]: + def is_set_inner(ctx: Context, _: Pipeline) -> bool: return bool(ctx.framework_states.get(SLOT_STORAGE_KEY, {}).get(self.name)) return is_set_inner @@ -200,14 +201,14 @@ def get_inner(ctx: Context, _: Pipeline) -> Union[str, None]: return get_inner - def unset_value(self): - def unset_inner(ctx: Context, _: Pipeline): + def unset_value(self) -> Callable[[Context, Pipeline], None]: + def unset_inner(ctx: Context, _: Pipeline) -> None: ctx.framework_states.setdefault(SLOT_STORAGE_KEY, {}) ctx.framework_states[SLOT_STORAGE_KEY][self.name] = None return unset_inner - def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]: + def fill_template(self, template: str) -> Callable[[Context, Pipeline], Union[str, None]]: """ Value Slot's `fill_template` method does not perform template filling on its own, but allows you to cut corners on some standard operations. E. g., if you include the following snippet in @@ -258,8 +259,8 @@ class RegexpSlot(ValueSlot): regexp: str match_group_idx: int = 0 - def fill_template(self, template: str) -> Callable: - def fill_inner(ctx: Context, pipeline: Pipeline): + def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]: + def fill_inner(ctx: Context, pipeline: Pipeline) -> str: checked_template = super(RegexpSlot, self).fill_template(template)(ctx, pipeline) if checked_template is None: # the check returning None means that an error has occured. return template @@ -269,7 +270,7 @@ def fill_inner(ctx: Context, pipeline: Pipeline): return fill_inner - def extract_value(self, ctx: Context, _: Pipeline): + def extract_value(self, ctx: Context, _: Pipeline) -> Any: search = re.search(self.regexp, ctx.last_request.text) self.value = search.group(self.match_group_idx) if search else None return self.value @@ -284,8 +285,8 @@ class FunctionSlot(ValueSlot): func: Callable[[str], str] - def fill_template(self, template: str) -> Callable: - def fill_inner(ctx: Context, pipeline: Pipeline): + def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]: + def fill_inner(ctx: Context, pipeline: Pipeline) -> str: checked_template = super(FunctionSlot, self).fill_template(template)(ctx, pipeline) if not checked_template: # the check returning None means that an error has occured. return template @@ -295,6 +296,6 @@ def fill_inner(ctx: Context, pipeline: Pipeline): return fill_inner - def extract_value(self, ctx: Context, _: Pipeline): + def extract_value(self, ctx: Context, _: Pipeline) -> Any: self.value = self.func(ctx.last_request.text) return self.value diff --git a/tutorials/slots/1_basic_example.py b/tutorials/slots/1_basic_example.py index c0775717d..248a2161a 100644 --- a/tutorials/slots/1_basic_example.py +++ b/tutorials/slots/1_basic_example.py @@ -26,7 +26,11 @@ from dff.script.slots import response as slot_rsp from dff.script.slots import conditions as slot_cnd -from dff.utils.testing import check_happy_path, is_interactive_mode, run_interactive_mode +from dff.utils.testing import ( + check_happy_path, + is_interactive_mode, + run_interactive_mode, +) # %% [markdown] """ @@ -45,8 +49,14 @@ person_slot = slots.GroupSlot( name="person", children=[ - slots.RegexpSlot(name="username", regexp=r"username is ([a-zA-Z]+)", match_group_idx=1), - slots.RegexpSlot(name="email", regexp=r"email is ([a-z@\.A-Z]+)", match_group_idx=1), + slots.RegexpSlot( + name="username", + regexp=r"username is ([a-zA-Z]+)", + match_group_idx=1, + ), + slots.RegexpSlot( + name="email", regexp=r"email is ([a-z@\.A-Z]+)", match_group_idx=1 + ), ], ) friend_slot = slots.GroupSlot( @@ -69,9 +79,13 @@ GLOBAL: {TRANSITIONS: {("username_flow", "ask"): cnd.regexp(r"^[sS]tart")}}, "username_flow": { LOCAL: { - PRE_TRANSITIONS_PROCESSING: {"get_slot": slot_procs.extract(["person/username"])}, + PRE_TRANSITIONS_PROCESSING: { + "get_slot": slot_procs.extract(["person/username"]) + }, TRANSITIONS: { - ("email_flow", "ask", 1.2): slot_cnd.is_set_all(["person/username"]), + ("email_flow", "ask", 1.2): slot_cnd.is_set_all( + ["person/username"] + ), ("username_flow", "repeat_question", 0.8): cnd.true(), }, }, @@ -79,12 +93,16 @@ RESPONSE: Message(text="Write your username (my username is ...):"), }, "repeat_question": { - RESPONSE: Message(text="Please, type your username again (my username is ...):") + RESPONSE: Message( + text="Please, type your username again (my username is ...):" + ) }, }, "email_flow": { LOCAL: { - PRE_TRANSITIONS_PROCESSING: {"get_slot": slot_procs.extract(["person/email"])}, + PRE_TRANSITIONS_PROCESSING: { + "get_slot": slot_procs.extract(["person/email"]) + }, TRANSITIONS: { ("friend_flow", "ask", 1.2): slot_cnd.is_set_all( ["person/username", "person/email"] @@ -96,12 +114,16 @@ RESPONSE: Message(text="Write your email (my email is ...):"), }, "repeat_question": { - RESPONSE: Message(text="Please, write your email again (my email is ...):") + RESPONSE: Message( + text="Please, write your email again (my email is ...):" + ) }, }, "friend_flow": { LOCAL: { - PRE_TRANSITIONS_PROCESSING: {"get_slots": slot_procs.extract(["friend"])}, + PRE_TRANSITIONS_PROCESSING: { + "get_slots": slot_procs.extract(["friend"]) + }, TRANSITIONS: { ("root", "utter", 1.2): slot_cnd.is_set_any( ["friend/first_name", "friend/last_name"] @@ -109,20 +131,31 @@ ("friend_flow", "repeat_question", 0.8): cnd.true(), }, }, - "ask": {RESPONSE: Message(text="Please, name me one of your friends: (John Doe)")}, + "ask": { + RESPONSE: Message( + text="Please, name me one of your friends: (John Doe)" + ) + }, "repeat_question": { - RESPONSE: Message(text="Please, name me one of your friends again: (John Doe)") + RESPONSE: Message( + text="Please, name me one of your friends again: (John Doe)" + ) }, }, "root": { - "start": {RESPONSE: Message(text=""), TRANSITIONS: {("username_flow", "ask"): cnd.true()}}, + "start": { + RESPONSE: Message(text=""), + TRANSITIONS: {("username_flow", "ask"): cnd.true()}, + }, "fallback": { RESPONSE: Message(text="Finishing query"), TRANSITIONS: {("username_flow", "ask"): cnd.true()}, }, "utter": { RESPONSE: slot_rsp.fill_template( - Message(text="Your friend is called {friend/first_name} {friend/last_name}") + Message( + text="Your friend is called {friend/first_name} {friend/last_name}" + ) ), TRANSITIONS: {("root", "utter_alternative"): cnd.true()}, }, @@ -138,14 +171,23 @@ # %% HAPPY_PATH = [ - (Message(text="hi"), Message(text="Write your username (my username is ...):")), - (Message(text="my username is groot"), Message(text="Write your email (my email is ...):")), + ( + Message(text="hi"), + Message(text="Write your username (my username is ...):"), + ), + ( + Message(text="my username is groot"), + Message(text="Write your email (my email is ...):"), + ), ( Message(text="my email is groot@gmail.com"), Message(text="Please, name me one of your friends: (John Doe)"), ), (Message(text="Bob Page"), Message(text="Your friend is called Bob Page")), - (Message(text="ok"), Message(text="Your username is groot. Your email is groot@gmail.com.")), + ( + Message(text="ok"), + Message(text="Your username is groot. Your email is groot@gmail.com."), + ), (Message(text="ok"), Message(text="Finishing query")), ] @@ -157,7 +199,9 @@ ) if __name__ == "__main__": - check_happy_path(pipeline, HAPPY_PATH) # This is a function for automatic tutorial running + check_happy_path( + pipeline, HAPPY_PATH + ) # This is a function for automatic tutorial running # (testing) with HAPPY_PATH # This runs tutorial in interactive mode if not in IPython env diff --git a/tutorials/slots/2_form_example.py b/tutorials/slots/2_form_example.py index cfb3a7f66..bfe015fe8 100644 --- a/tutorials/slots/2_form_example.py +++ b/tutorials/slots/2_form_example.py @@ -10,11 +10,22 @@ # %% from dff.script import labels as lbl -from dff.script import RESPONSE, TRANSITIONS, PRE_TRANSITIONS_PROCESSING, GLOBAL, LOCAL, Message +from dff.script import ( + RESPONSE, + TRANSITIONS, + PRE_TRANSITIONS_PROCESSING, + GLOBAL, + LOCAL, + Message, +) from dff.pipeline import Pipeline -from dff.utils.testing import check_happy_path, is_interactive_mode, run_interactive_mode +from dff.utils.testing import ( + check_happy_path, + is_interactive_mode, + run_interactive_mode, +) from dff.script import conditions as cnd from dff.script import slots @@ -70,7 +81,9 @@ script = { GLOBAL: { TRANSITIONS: { - restaurant_form.to_next_label(1.1): restaurant_form.has_state(FormState.ACTIVE), + restaurant_form.to_next_label(1.1): restaurant_form.has_state( + FormState.ACTIVE + ), }, PRE_TRANSITIONS_PROCESSING: { "extract_cuisine": slot_procs.extract([restaurant_cuisine.name]), @@ -87,36 +100,52 @@ restaurant_form.has_state(FormState.INACTIVE), ] ), # this transition ensures the form loop can be left - ("restaurant", "form_filled", 0.9): restaurant_form.has_state(FormState.COMPLETE), + ("restaurant", "form_filled", 0.9): restaurant_form.has_state( + FormState.COMPLETE + ), } }, "offer": { RESPONSE: slot_rsp.fill_template( - Message(text="Would you like me to find a {cuisine} cuisine restaurant?") + Message( + text="Would you like me to find a {cuisine} cuisine restaurant?" + ) ), - TRANSITIONS: {lbl.forward(1.1): cnd.regexp(r"[yY]es|[yY]eah|[Oo][Kk]|[Ff]ine")}, + TRANSITIONS: { + lbl.forward(1.1): cnd.regexp(r"[yY]es|[yY]eah|[Oo][Kk]|[Ff]ine") + }, PRE_TRANSITIONS_PROCESSING: { "reset_form": restaurant_form.update_state(FormState.INACTIVE), - "reset_slots": slot_procs.unset([restaurant_address.name, number_of_people.name]), + "reset_slots": slot_procs.unset( + [restaurant_address.name, number_of_people.name] + ), }, # Explicitly resetting form and slot states }, "offer_accepted": { RESPONSE: Message(text="Very well then, processing your request."), PRE_TRANSITIONS_PROCESSING: { - "activate_form": restaurant_form.update_state(slots.FormState.ACTIVE), + "activate_form": restaurant_form.update_state( + slots.FormState.ACTIVE + ), }, }, "form_filled": { RESPONSE: slot_rsp.fill_template( - Message(text="All done, a table for {numberofpeople} has been reserved") + Message( + text="All done, a table for {numberofpeople} has been reserved" + ) ), TRANSITIONS: {("chitchat", "chat_3", 1.1): cnd.true()}, }, "cuisine": { - RESPONSE: Message(text="What kind of cuisine would you like to have?"), + RESPONSE: Message( + text="What kind of cuisine would you like to have?" + ), }, "address": { - RESPONSE: Message(text="In what area would you like to find a restaurant?"), + RESPONSE: Message( + text="In what area would you like to find a restaurant?" + ), }, "number": { RESPONSE: Message(text="How many people would you like to invite?"), @@ -136,10 +165,17 @@ RESPONSE: Message(text="Did you like the latest Star Wars film?"), TRANSITIONS: {lbl.to_fallback(1.1): cnd.true()}, }, - "chat_4": {RESPONSE: Message(text="Who do you think will win the Champions League?")}, + "chat_4": { + RESPONSE: Message( + text="Who do you think will win the Champions League?" + ) + }, }, "root": { - "start": {RESPONSE: Message(text=""), TRANSITIONS: {("chitchat", "chat_1", 2): cnd.true()}}, + "start": { + RESPONSE: Message(text=""), + TRANSITIONS: {("chitchat", "chat_1", 2): cnd.true()}, + }, "fallback": { RESPONSE: Message(text="Nice chatting with you!"), TRANSITIONS: {("chitchat", "chat_1", 2): cnd.true()}, @@ -150,7 +186,10 @@ HAPPY_PATH = [ (Message(text="hi"), Message(text="How's life?")), (Message(text="good"), Message(text="What kind of cuisine do you like?")), - (Message(text="none"), Message(text="Did you like the latest Star Wars film?")), + ( + Message(text="none"), + Message(text="Did you like the latest Star Wars film?"), + ), (Message(text="yes"), Message(text="Nice chatting with you!")), (Message(text="hi"), Message(text="How's life?")), (Message(text="good"), Message(text="What kind of cuisine do you like?")), @@ -158,14 +197,26 @@ Message(text="french cuisine"), Message(text="Would you like me to find a french cuisine restaurant?"), ), - (Message(text="yes"), Message(text="Very well then, processing your request.")), - (Message(text="ok"), Message(text="In what area would you like to find a restaurant?")), - (Message(text="in London"), Message(text="How many people would you like to invite?")), + ( + Message(text="yes"), + Message(text="Very well then, processing your request."), + ), + ( + Message(text="ok"), + Message(text="In what area would you like to find a restaurant?"), + ), + ( + Message(text="in London"), + Message(text="How many people would you like to invite?"), + ), ( Message(text="3 people"), Message(text="All done, a table for 3 has been reserved"), ), - (Message(text="ok"), Message(text="Did you like the latest Star Wars film?")), + ( + Message(text="ok"), + Message(text="Did you like the latest Star Wars film?"), + ), (Message(text="yes"), Message(text="Nice chatting with you!")), ] diff --git a/tutorials/slots/3_handlers_example.py b/tutorials/slots/3_handlers_example.py index 7644254a6..d73561f10 100644 --- a/tutorials/slots/3_handlers_example.py +++ b/tutorials/slots/3_handlers_example.py @@ -23,7 +23,11 @@ from dff.pipeline import Pipeline -from dff.utils.testing import check_happy_path, is_interactive_mode, run_interactive_mode +from dff.utils.testing import ( + check_happy_path, + is_interactive_mode, + run_interactive_mode, +) from dff.script import slots from dff.script.slots import conditions as slot_cnd from dff.script.slots import processing as slot_procs @@ -34,9 +38,17 @@ slots.GroupSlot( name="pet_info", children=[ - slots.RegexpSlot(name="sort", regexp=r"(dog|cat)", match_group_idx=1), - slots.RegexpSlot(name="gender", regexp=r"(she|(?<=[^s])he|^he)", match_group_idx=1), - slots.RegexpSlot(name="behaviour", regexp=r"(good|bad)", match_group_idx=1), + slots.RegexpSlot( + name="sort", regexp=r"(dog|cat)", match_group_idx=1 + ), + slots.RegexpSlot( + name="gender", + regexp=r"(she|(?<=[^s])he|^he)", + match_group_idx=1, + ), + slots.RegexpSlot( + name="behaviour", regexp=r"(good|bad)", match_group_idx=1 + ), ], ) ], @@ -58,7 +70,9 @@ def custom_behaviour_question(ctx: Context, pipeline: Pipeline): template = "Is {pet/pet_info/gender} a good " middle = " or a bad " - new_template = slots.get_filled_template(template, ctx, pipeline, slots=["pet/pet_info/gender"]) + new_template = slots.get_filled_template( + template, ctx, pipeline, slots=["pet/pet_info/gender"] + ) gender = slots.get_values(ctx, pipeline, slots=["pet/pet_info/gender"])[0] if gender is None: new_template = slots.get_filled_template( @@ -92,22 +106,34 @@ def custom_esteem(ctx: Context, pipeline: Pipeline): GLOBAL: {TRANSITIONS: {("pet_flow", "ask"): cnd.regexp(r"^[sS]tart")}}, "pet_flow": { LOCAL: { - PRE_TRANSITIONS_PROCESSING: {"get_slot": slot_procs.extract(["pet/pet_info/sort"])}, + PRE_TRANSITIONS_PROCESSING: { + "get_slot": slot_procs.extract(["pet/pet_info/sort"]) + }, TRANSITIONS: { - ("gender_flow", "ask", 1.2): slot_cnd.is_set_all(["pet/pet_info/sort"]), + ("gender_flow", "ask", 1.2): slot_cnd.is_set_all( + ["pet/pet_info/sort"] + ), ("pet_flow", "repeat_question", 0.8): cnd.true(), }, }, "ask": { - RESPONSE: Message(text="I heard that you have a pet. Is it a cat, or a dog?"), + RESPONSE: Message( + text="I heard that you have a pet. Is it a cat, or a dog?" + ), + }, + "repeat_question": { + RESPONSE: Message(text="Seriously, is it a cat, or a dog?") }, - "repeat_question": {RESPONSE: Message(text="Seriously, is it a cat, or a dog?")}, }, "gender_flow": { LOCAL: { - PRE_TRANSITIONS_PROCESSING: {"get_slot": slot_procs.extract(["pet/pet_info/gender"])}, + PRE_TRANSITIONS_PROCESSING: { + "get_slot": slot_procs.extract(["pet/pet_info/gender"]) + }, TRANSITIONS: { - ("behaviour_flow", "ask", 1.2): slot_cnd.is_set_all(["pet/pet_info/gender"]), + ("behaviour_flow", "ask", 1.2): slot_cnd.is_set_all( + ["pet/pet_info/gender"] + ), ("gender_flow", "repeat_question", 0.8): cnd.true(), }, }, @@ -115,7 +141,9 @@ def custom_esteem(ctx: Context, pipeline: Pipeline): RESPONSE: Message(text="Great! Is it a he, or a she?"), }, "repeat_question": { - RESPONSE: Message(text="I mean, is it a he, or a she? Name whatever is closer.") + RESPONSE: Message( + text="I mean, is it a he, or a she? Name whatever is closer." + ) }, }, "behaviour_flow": { @@ -124,7 +152,9 @@ def custom_esteem(ctx: Context, pipeline: Pipeline): "get_slot": slot_procs.extract(["pet/pet_info/behaviour"]) }, TRANSITIONS: { - ("root", "esteem", 1.2): slot_cnd.is_set_all(["pet/pet_info/behaviour"]), + ("root", "esteem", 1.2): slot_cnd.is_set_all( + ["pet/pet_info/behaviour"] + ), ("behaviour_flow", "repeat_question", 0.8): cnd.true(), }, }, @@ -132,7 +162,10 @@ def custom_esteem(ctx: Context, pipeline: Pipeline): "repeat_question": {RESPONSE: custom_behaviour_question}, }, "root": { - "start": {RESPONSE: Message(text=""), TRANSITIONS: {("pet_flow", "ask"): cnd.true()}}, + "start": { + RESPONSE: Message(text=""), + TRANSITIONS: {("pet_flow", "ask"): cnd.true()}, + }, "fallback": { RESPONSE: Message(text="It's been a nice talk! See you."), TRANSITIONS: {("pet_flow", "ask"): cnd.true()}, @@ -146,12 +179,18 @@ def custom_esteem(ctx: Context, pipeline: Pipeline): } HAPPY_PATH = [ - (Message(text="hi"), Message(text="I heard that you have a pet. Is it a cat, or a dog?")), + ( + Message(text="hi"), + Message(text="I heard that you have a pet. Is it a cat, or a dog?"), + ), (Message(text="it is a dog"), Message(text="Great! Is it a he, or a she?")), (Message(text="he"), Message(text="Is he a good boy or a bad boy?")), (Message(text="it's bad"), Message(text="Sorry to hear that.")), (Message(text="ok"), Message(text="It's been a nice talk! See you.")), - (Message(text="ok"), Message(text="I heard that you have a pet. Is it a cat, or a dog?")), + ( + Message(text="ok"), + Message(text="I heard that you have a pet. Is it a cat, or a dog?"), + ), (Message(text="a CAT"), Message(text="Seriously, is it a cat, or a dog?")), (Message(text="it's a cat"), Message(text="Great! Is it a he, or a she?")), (Message(text="she"), Message(text="Is she a good girl or a bad girl?")),