Skip to content

Commit

Permalink
typing + linting fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Dec 15, 2023
1 parent 18d445b commit 25a2eba
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 102 deletions.
10 changes: 4 additions & 6 deletions dff/script/slots/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ def check_slot_state(ctx: Context, _: Pipeline) -> bool:
return check_slot_state


def is_set_all(paths: list):
cond = all_condition(*[slot_extracted_condition(path) for path in paths])
return cond
def is_set_all(paths: list) -> StartConditionCheckerFunction:
return all_condition(*[slot_extracted_condition(path) for path in paths])


def is_set_any(paths: list):
cond = any_condition(*[slot_extracted_condition(path) for path in paths])
return cond
def is_set_any(paths: list) -> StartConditionCheckerFunction:
return any_condition(*[slot_extracted_condition(path) for path in paths])
14 changes: 8 additions & 6 deletions dff/script/slots/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ class FormPolicy(BaseModel):
allowed_repeats: int = Field(default=0, gt=-1)
node_cache: Dict[NodeLabel2Type, int] = Field(default_factory=Counter)

def __init__(self, name: str, mapping: Dict[str, List[NodeLabel2Type]], *, allowed_repeats: int = 0, **data):
def __init__(
self, name: str, mapping: Dict[str, List[NodeLabel2Type]], *, allowed_repeats: int = 0, **data
) -> None:
"""
Create a new form.
Expand Down Expand Up @@ -121,7 +123,7 @@ def to_next_label_inner(ctx: Context, pipeline: Pipeline) -> NodeLabel3Type:
return to_next_label_inner

@validate_call
def has_state(self, state: FormState):
def has_state(self, state: FormState) -> Callable[[Context, Pipeline], bool]:
"""
This method produces a dff.core.engine condition that yields `True` if the state of the form
equals the passed :class:`~.FormState` or `False` otherwise.
Expand All @@ -137,7 +139,7 @@ def is_active_inner(ctx: Context, pipeline: Pipeline) -> bool:
return is_active_inner

@validate_call
def update_state(self, state: Optional[FormState] = None):
def update_state(self, state: Optional[FormState] = None) -> Callable[[Context, Pipeline], None]:
"""
This method updates the form state that is stored in the context.
Expand All @@ -148,7 +150,7 @@ def update_state(self, state: Optional[FormState] = None):
"""

def update_inner(ctx: Context, pipeline: Pipeline):
def update_inner(ctx: Context, pipeline: Pipeline) -> None:
ctx.framework_states.setdefault(FORM_STORAGE_KEY, {})

if state:
Expand All @@ -165,8 +167,8 @@ def update_inner(ctx: Context, pipeline: Pipeline):

return update_inner

def get_values(self):
def get_values_inner(ctx: Context, pipeline: Pipeline):
def get_values(self) -> Callable[[Context, Pipeline], List[Dict[str, Union[str, None]]]]:
def get_values_inner(ctx: Context, pipeline: Pipeline) -> List[Dict[str, Union[str, None]]]:
slots = list(self.mapping.keys())
return get_values(ctx, pipeline, slots)

Expand Down
8 changes: 5 additions & 3 deletions dff/script/slots/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
This module is for general functions that can be used in processing, conditions, or responses.
"""
import logging
from typing import Dict, Optional, List
from typing import Any, Dict, Optional, List, Union

from dff.script import Context
from dff.pipeline import Pipeline
Expand All @@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)


def extract(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None) -> list:
def extract(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None) -> List[Optional[Any]]:
"""
Extract the specified slots and return the received values as a list.
If the value of a particular slot cannot be extracted, None is included instead.
Expand Down Expand Up @@ -56,7 +56,9 @@ def extract(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None)
return results


def get_values(ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None) -> list:
def get_values(
ctx: Context, pipeline: Pipeline, slots: Optional[List[str]] = None
) -> List[Dict[str, Union[str, None]]]:
"""
Get values of the specified slots, assuming that they have been extracted beforehand.
If slot argument is omitted, values of all slots will be returned.
Expand Down
16 changes: 8 additions & 8 deletions dff/script/slots/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
This module encapsulates operations that can be done to slots during the processing stage.
"""
import logging
from typing import Optional, List, Callable
from typing import Awaitable, Optional, List, Callable

from pydantic import validate_call

from dff.script import Context, Message
from dff.script import Context
from dff.pipeline import Pipeline

from .handlers import get_filled_template, extract as extract_handler, unset as unset_handler
Expand All @@ -17,32 +17,32 @@


@validate_call
def extract(slots: Optional[List[str]]) -> Callable:
def extract(slots: Optional[List[str]]) -> Callable[[Context, Pipeline], Awaitable[None]]:
"""
Extract slots from a specified list.
:param slots: List of slot names to extract.
Names of slots inside groups should be prefixed with group names, separated by '/': profile/username.
"""

async def extract_inner(ctx: Context, pipeline: Pipeline):
async def extract_inner(ctx: Context, pipeline: Pipeline) -> None:
_ = extract_handler(ctx, pipeline, slots)
return

return extract_inner


@validate_call
def unset(slots: Optional[List[str]] = None):
def unset_inner(ctx: Context, pipeline: Pipeline):
def unset(slots: Optional[List[str]] = None) -> Callable[[Context, Pipeline], None]:
def unset_inner(ctx: Context, pipeline: Pipeline) -> None:
unset_handler(ctx, pipeline, slots)
return

return unset_inner


@validate_call
def fill_template(slots: Optional[List[str]] = None):
def fill_template(slots: Optional[List[str]] = None) -> Callable[[Context, Pipeline], None]:
"""
Fill the response template in the current node.
Response should be an instance of :py:class:`~.Message`.
Expand All @@ -51,7 +51,7 @@ def fill_template(slots: Optional[List[str]] = None):
:param slots: Slot names to use. If this parameter is omitted, all slots will be used.
"""

def fill_inner(ctx: Context, pipeline: Pipeline):
def fill_inner(ctx: Context, pipeline: Pipeline) -> None:
# get current node response
response = ctx.current_node.response
if callable(response):
Expand Down
6 changes: 3 additions & 3 deletions dff/script/slots/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
This module is for functions that should be executed at the response stage.
They produce the response that will be ultimately given to the user.
"""
from typing import Optional, List
from typing import Callable, Optional, List

from pydantic import validate_call

Expand All @@ -15,7 +15,7 @@


@validate_call
def fill_template(template: Message, slots: Optional[List[str]] = None):
def fill_template(template: Message, slots: Optional[List[str]] = None) -> Callable[[Context, Pipeline], Message]:
"""
Fill a template with slot values.
Response should be an instance of :py:class:`~.Message` class.
Expand All @@ -24,7 +24,7 @@ def fill_template(template: Message, slots: Optional[List[str]] = None):
:param slots: Slot names to use. If this parameter is omitted, all slots will be used.
"""

def fill_inner(ctx: Context, pipeline: Pipeline):
def fill_inner(ctx: Context, pipeline: Pipeline) -> Message:
new_template = template.model_copy()
new_text = get_filled_template(template.text, ctx, pipeline, slots)
new_template.text = new_text
Expand Down
51 changes: 26 additions & 25 deletions dff/script/slots/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ class BaseSlot(BaseModel, ABC, arbitrary_types_allowed=True):
children: Optional[Dict[str, "BaseSlot"]]

@field_validator("name", mode="before")
def validate_name(cls, name: str):
def validate_name(cls, name: str) -> str:
if "/" in name:
raise ValueError("Character `/` cannot be used in slot names.")
return name

def __init__(self, name: str, **data):
def __init__(self, name: str, **data) -> None:
super().__init__(name=name, **data)

def __deepcopy__(self, *args, **kwargs):
def __deepcopy__(self) -> "BaseSlot":
return copy(self)

def __eq__(self, other: "BaseSlot"):
return self.dict(exclude={"name"}) == other.dict(exclude={"name"})
def __eq__(self, other: "BaseSlot") -> bool:
return self.model_dump(exclude={"name"}) == other.model_dump(exclude={"name"})

def has_children(self) -> bool:
return self.children is not None and len(self.children) > 0
Expand Down Expand Up @@ -79,16 +79,16 @@ class _GroupSlot(BaseSlot):
children: Dict[str, BaseSlot] = Field(default_factory=dict)

@field_validator("children", mode="before")
def validate_children(cls, children, values: dict):
def validate_children(cls, children: Iterable, values: Dict[str, Any]) -> Dict[str, BaseSlot]:
if not isinstance(children, dict) and isinstance(children, Iterable):
children = {child.name: child for child in children}
if len(children) == 0:
name = values["name"]
raise ValueError(f"Error in slot {name}: group slot should have at least one child or more.")
return children

def is_set(self):
def is_set_inner(ctx: Context, pipeline: Pipeline):
def is_set(self) -> Callable[[Context, Pipeline], bool]:
def is_set_inner(ctx: Context, pipeline: Pipeline) -> bool:
return all([child.is_set()(ctx, pipeline) for child in self.children.values()])

return is_set_inner
Expand All @@ -105,14 +105,14 @@ def get_inner(ctx: Context, pipeline: Pipeline) -> Dict[str, Union[str, None]]:

return get_inner

def unset_value(self):
def unset_inner(ctx: Context, pipeline: Pipeline):
def unset_value(self) -> Callable[[Context, Pipeline], None]:
def unset_inner(ctx: Context, pipeline: Pipeline) -> None:
for child in self.children.values():
child.unset_value()(ctx, pipeline)

return unset_inner

def fill_template(self, template: str) -> Callable:
def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]:
def fill_inner(ctx: Context, pipeline: Pipeline) -> str:
new_template = template
for _, child in self.children.items():
Expand All @@ -122,7 +122,7 @@ def fill_inner(ctx: Context, pipeline: Pipeline) -> str:

return fill_inner

def extract_value(self, ctx: Context, pipeline: Pipeline):
def extract_value(self, ctx: Context, pipeline: Pipeline) -> Any:
for child in self.children.values():
_ = child.extract_value(ctx, pipeline)
return self.get_value()(ctx, pipeline)
Expand All @@ -133,6 +133,7 @@ class RootSlot(_GroupSlot):
"""
Root slot is a universally unique slot group that automatically
registers all the other slots and makes them globally available.
"""

@staticmethod
Expand All @@ -148,7 +149,7 @@ def flatten_slot_tree(node: BaseSlot) -> Tuple[Dict[str, BaseSlot], Dict[str, Ba
remove_nodes.update(child_remove_nodes)
return add_nodes, remove_nodes

def add_slots(self, slots: Union[BaseSlot, Iterable]):
def add_slots(self, slots: Union[BaseSlot, Iterable]) -> None:
if isinstance(slots, BaseSlot):
add_nodes, _ = self.flatten_slot_tree(slots)
self.children.update(add_nodes)
Expand All @@ -161,7 +162,7 @@ def add_slots(self, slots: Union[BaseSlot, Iterable]):


class ChildSlot(BaseSlot):
def __init__(self, *, name, **kwargs):
def __init__(self, *, name, **kwargs) -> None:
super().__init__(name=name, **kwargs)
root_slot.add_slots(self)

Expand All @@ -188,8 +189,8 @@ class ValueSlot(ChildSlot):
children: None = Field(None)
value: Any = None

def is_set(self):
def is_set_inner(ctx: Context, _: Pipeline):
def is_set(self) -> Callable[[Context, Pipeline], bool]:
def is_set_inner(ctx: Context, _: Pipeline) -> bool:
return bool(ctx.framework_states.get(SLOT_STORAGE_KEY, {}).get(self.name))

return is_set_inner
Expand All @@ -200,14 +201,14 @@ def get_inner(ctx: Context, _: Pipeline) -> Union[str, None]:

return get_inner

def unset_value(self):
def unset_inner(ctx: Context, _: Pipeline):
def unset_value(self) -> Callable[[Context, Pipeline], None]:
def unset_inner(ctx: Context, _: Pipeline) -> None:
ctx.framework_states.setdefault(SLOT_STORAGE_KEY, {})
ctx.framework_states[SLOT_STORAGE_KEY][self.name] = None

return unset_inner

def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]:
def fill_template(self, template: str) -> Callable[[Context, Pipeline], Union[str, None]]:
"""
Value Slot's `fill_template` method does not perform template filling on its own, but allows you
to cut corners on some standard operations. E. g., if you include the following snippet in
Expand Down Expand Up @@ -258,8 +259,8 @@ class RegexpSlot(ValueSlot):
regexp: str
match_group_idx: int = 0

def fill_template(self, template: str) -> Callable:
def fill_inner(ctx: Context, pipeline: Pipeline):
def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]:
def fill_inner(ctx: Context, pipeline: Pipeline) -> str:
checked_template = super(RegexpSlot, self).fill_template(template)(ctx, pipeline)
if checked_template is None: # the check returning None means that an error has occured.
return template
Expand All @@ -269,7 +270,7 @@ def fill_inner(ctx: Context, pipeline: Pipeline):

return fill_inner

def extract_value(self, ctx: Context, _: Pipeline):
def extract_value(self, ctx: Context, _: Pipeline) -> Any:
search = re.search(self.regexp, ctx.last_request.text)
self.value = search.group(self.match_group_idx) if search else None
return self.value
Expand All @@ -284,8 +285,8 @@ class FunctionSlot(ValueSlot):

func: Callable[[str], str]

def fill_template(self, template: str) -> Callable:
def fill_inner(ctx: Context, pipeline: Pipeline):
def fill_template(self, template: str) -> Callable[[Context, Pipeline], str]:
def fill_inner(ctx: Context, pipeline: Pipeline) -> str:
checked_template = super(FunctionSlot, self).fill_template(template)(ctx, pipeline)
if not checked_template: # the check returning None means that an error has occured.
return template
Expand All @@ -295,6 +296,6 @@ def fill_inner(ctx: Context, pipeline: Pipeline):

return fill_inner

def extract_value(self, ctx: Context, _: Pipeline):
def extract_value(self, ctx: Context, _: Pipeline) -> Any:
self.value = self.func(ctx.last_request.text)
return self.value
Loading

0 comments on commit 25a2eba

Please sign in to comment.