Skip to content

Commit

Permalink
Enable mypy override check and fix errors (RasaHQ#10965)
Browse files Browse the repository at this point in the history
* enable override check, fix errors in events.py and utils package

* fix errors in shared package

* fix errors in nlu package

* fix errors in policies

* fix last errors in core package

* docstring fixes

* use generics to fix override errors in brokers and events

* address review comments

* make attribute private

* fix error in tracker featurizer

* modify tracker with cached states

* refactor tracker stores with mixin serialisation class

* undo serialise_tracker method as staticmethod, change type in tracker from_dict method

* revert to staticmethod for tracker stores, fix type in __contains__

* address final review comments

* fix failed tests
  • Loading branch information
ancalita authored Mar 10, 2022
1 parent b6fd93e commit 291f6c6
Show file tree
Hide file tree
Showing 23 changed files with 232 additions and 97 deletions.
1 change: 1 addition & 0 deletions changelog/9097.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable `mypy` `override` check and fix any resulting errors.
9 changes: 6 additions & 3 deletions rasa/core/brokers/broker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from asyncio import AbstractEventLoop
from typing import Any, Dict, Text, Optional, Union
from typing import Any, Dict, Text, Optional, Union, TypeVar, Type

import aiormq

Expand All @@ -12,6 +12,9 @@
logger = logging.getLogger(__name__)


EB = TypeVar("EB", bound="EventBroker")


class EventBroker:
"""Base class for any event broker implementation."""

Expand Down Expand Up @@ -39,10 +42,10 @@ async def create(

@classmethod
async def from_endpoint_config(
cls,
cls: Type[EB],
broker_config: EndpointConfig,
event_loop: Optional[AbstractEventLoop] = None,
) -> "EventBroker":
) -> Optional[EB]:
"""Creates an `EventBroker` from the endpoint configuration.
Args:
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/channels/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ async def send_image_url(
channel=recipient, as_user=True, text=image, blocks=[image_block]
)

async def send_attachment(
async def send_attachment( # type: ignore[override]
self, recipient_id: Text, attachment: Dict[Text, Any], **kwargs: Any
) -> None:
"""Sends message with attachment."""
recipient = self.slack_channel or recipient_id
await self._post_message(
channel=recipient, as_user=True, attachments=[attachment], **kwargs
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/channels/socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def send_custom_json(

await self.sio.emit(self.bot_message_evt, **json_message)

async def send_attachment(
async def send_attachment( # type: ignore[override]
self, recipient_id: Text, attachment: Dict[Text, Any], **kwargs: Any
) -> None:
"""Sends an attachment to the user."""
Expand Down
8 changes: 4 additions & 4 deletions rasa/core/featurizers/tracker_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,9 +508,7 @@ def training_states_labels_and_entities(
domain: Domain,
omit_unset_slots: bool = False,
ignore_action_unlikely_intent: bool = False,
) -> Tuple[
List[List[State]], List[List[Optional[Text]]], List[List[Dict[Text, Any]]]
]:
) -> Tuple[List[List[State]], List[List[Text]], List[List[Dict[Text, Any]]]]:
"""Transforms trackers to states, action labels, and entity data.
Args:
Expand Down Expand Up @@ -561,7 +559,9 @@ def training_states_labels_and_entities(
if not event.unpredictable:
# only actions which can be
# predicted at a stories start
actions.append(event.action_name or event.action_text)
action = event.action_name or event.action_text
if action is not None:
actions.append(action)
entities.append(entity_data)
else:
# unpredictable actions can be
Expand Down
5 changes: 4 additions & 1 deletion rasa/core/policies/memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def _create_lookup_from_states(

return lookup

def _create_feature_key(self, states: List[State]) -> Text:
def _create_feature_key(self, states: List[State]) -> Optional[Text]:
if not states:
return None

# we sort keys to make sure that the same states
# represented as dictionaries have the same json strings
# quotes are removed for aesthetic reasons
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def train(
training_trackers: List[TrackerWithCachedStates],
domain: Domain,
precomputations: Optional[MessageContainerForCoreFeaturization] = None,
**kwargs: Any,
) -> Resource:
"""Trains the policy (see parent class for full docstring)."""
if not training_trackers:
Expand Down Expand Up @@ -809,8 +810,8 @@ def predict_action_probabilities(
self,
tracker: DialogueStateTracker,
domain: Domain,
precomputations: Optional[MessageContainerForCoreFeaturization] = None,
rule_only_data: Optional[Dict[Text, Any]] = None,
precomputations: Optional[MessageContainerForCoreFeaturization] = None,
**kwargs: Any,
) -> PolicyPrediction:
"""Predicts the next action (see parent class for full docstring)."""
Expand Down
6 changes: 3 additions & 3 deletions rasa/core/policies/unexpected_intent_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def __init__(

common.mark_as_experimental_feature("UnexpecTED Intent Policy")

def _standard_featurizer(self) -> TrackerFeaturizer:
def _standard_featurizer(self) -> IntentMaxHistoryTrackerFeaturizer:
return IntentMaxHistoryTrackerFeaturizer(
IntentTokenizerSingleStateFeaturizer(),
max_history=self.config.get(POLICY_MAX_HISTORY),
Expand Down Expand Up @@ -565,18 +565,18 @@ def predict_action_probabilities(
self,
tracker: DialogueStateTracker,
domain: Domain,
precomputations: Optional[MessageContainerForCoreFeaturization] = None,
rule_only_data: Optional[Dict[Text, Any]] = None,
precomputations: Optional[MessageContainerForCoreFeaturization] = None,
**kwargs: Any,
) -> PolicyPrediction:
"""Predicts the next action the bot should take after seeing the tracker.
Args:
tracker: Tracker containing past conversation events.
domain: Domain of the assistant.
precomputations: Contains precomputed features and attributes.
rule_only_data: Slots and loops which are specific to rules and hence
should be ignored by this policy.
precomputations: Contains precomputed features and attributes.
Returns:
The policy's prediction (e.g. the probabilities for the actions).
Expand Down
107 changes: 78 additions & 29 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Union,
TYPE_CHECKING,
Generator,
TypeVar,
Generic,
)

from boto3.dynamodb.conditions import Key
Expand Down Expand Up @@ -68,6 +70,40 @@ class TrackerDeserialisationException(RasaException):
"""Raised when an error is encountered while deserialising a tracker."""


SerializationType = TypeVar("SerializationType")


class SerializedTrackerRepresentation(Generic[SerializationType]):
"""Mixin class for specifying different serialization methods per tracker store."""

@staticmethod
def serialise_tracker(tracker: DialogueStateTracker) -> SerializationType:
"""Requires implementation to return representation of tracker."""
raise NotImplementedError()


class SerializedTrackerAsText(SerializedTrackerRepresentation[Text]):
"""Mixin class that returns the serialized tracker as string."""

@staticmethod
def serialise_tracker(tracker: DialogueStateTracker) -> Text:
"""Serializes the tracker, returns representation of the tracker."""
dialogue = tracker.as_dialogue()

return json.dumps(dialogue.as_dict())


class SerializedTrackerAsDict(SerializedTrackerRepresentation[Dict]):
"""Mixin class that returns the serialized tracker as dictionary."""

@staticmethod
def serialise_tracker(tracker: DialogueStateTracker) -> Dict:
"""Serializes the tracker, returns representation of the tracker."""
d = tracker.as_dialogue().as_dict()
d.update({"sender_id": tracker.sender_id})
return d


class TrackerStore:
"""Represents common behavior and interface for all `TrackerStore`s."""

Expand All @@ -85,7 +121,7 @@ def __init__(
destination.
kwargs: Additional kwargs.
"""
self.domain = domain or Domain.empty()
self._domain = domain or Domain.empty()
self.event_broker = event_broker
self.max_event_history = None

Expand Down Expand Up @@ -141,10 +177,10 @@ def get_or_create_tracker(
return tracker

def init_tracker(self, sender_id: Text) -> "DialogueStateTracker":
"""Returns a Dialogue State Tracker"""
"""Returns a Dialogue State Tracker."""
return DialogueStateTracker(
sender_id,
self.domain.slots if self.domain else None,
self.domain.slots,
max_event_history=self.max_event_history,
)

Expand Down Expand Up @@ -237,16 +273,9 @@ def number_of_existing_events(self, sender_id: Text) -> int:
return len(old_tracker.events) if old_tracker else 0

def keys(self) -> Iterable[Text]:
"""Returns the set of values for the tracker store's primary key"""
"""Returns the set of values for the tracker store's primary key."""
raise NotImplementedError()

@staticmethod
def serialise_tracker(tracker: DialogueStateTracker) -> Text:
"""Serializes the tracker, returns representation of the tracker."""
dialogue = tracker.as_dialogue()

return json.dumps(dialogue.as_dict())

def deserialise_tracker(
self, sender_id: Text, serialised_tracker: Union[Text, bytes]
) -> Optional[DialogueStateTracker]:
Expand All @@ -266,16 +295,26 @@ def deserialise_tracker(

return tracker

@property
def domain(self) -> Domain:
"""Returns the domain of the tracker store."""
return self._domain

class InMemoryTrackerStore(TrackerStore):
"""Stores conversation history in memory"""
@domain.setter
def domain(self, domain: Optional[Domain]) -> None:
self._domain = domain or Domain.empty()


class InMemoryTrackerStore(TrackerStore, SerializedTrackerAsText):
"""Stores conversation history in memory."""

def __init__(
self,
domain: Domain,
event_broker: Optional[EventBroker] = None,
**kwargs: Dict[Text, Any],
) -> None:
"""Initializes the tracker store."""
self.store: Dict[Text, Text] = {}
super().__init__(domain, event_broker, **kwargs)

Expand All @@ -286,6 +325,7 @@ def save(self, tracker: DialogueStateTracker) -> None:
self.store[tracker.sender_id] = serialised

def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
"""Returns tracker matching sender_id."""
if sender_id in self.store:
logger.debug(f"Recreating tracker for id '{sender_id}'")
return self.deserialise_tracker(sender_id, self.store[sender_id])
Expand All @@ -295,12 +335,12 @@ def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
return None

def keys(self) -> Iterable[Text]:
"""Returns sender_ids of the Tracker Store in memory"""
"""Returns sender_ids of the Tracker Store in memory."""
return self.store.keys()


class RedisTrackerStore(TrackerStore):
"""Stores conversation history in Redis"""
class RedisTrackerStore(TrackerStore, SerializedTrackerAsText):
"""Stores conversation history in Redis."""

def __init__(
self,
Expand All @@ -318,6 +358,7 @@ def __init__(
ssl_ca_certs: Optional[Text] = None,
**kwargs: Dict[Text, Any],
) -> None:
"""Initializes the tracker store."""
import redis

self.red = redis.StrictRedis(
Expand Down Expand Up @@ -388,8 +429,8 @@ def keys(self) -> Iterable[Text]:
return self.red.keys(self.key_prefix + "*")


class DynamoTrackerStore(TrackerStore):
"""Stores conversation history in DynamoDB"""
class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
"""Stores conversation history in DynamoDB."""

def __init__(
self,
Expand Down Expand Up @@ -451,12 +492,17 @@ def save(self, tracker: DialogueStateTracker) -> None:

self.db.put_item(Item=serialized)

def serialise_tracker(self, tracker: "DialogueStateTracker") -> Dict:
"""Serializes the tracker, returns object with decimal types."""
d = tracker.as_dialogue().as_dict()
d.update({"sender_id": tracker.sender_id})
# DynamoDB cannot store `float`s, so we'll convert them to `Decimal`s
return core_utils.replace_floats_with_decimals(d)
@staticmethod
def serialise_tracker(
tracker: "DialogueStateTracker",
) -> Dict:
"""Serializes the tracker, returns object with decimal types.
DynamoDB cannot store `float`s, so we'll convert them to `Decimal`s.
"""
return core_utils.replace_floats_with_decimals(
SerializedTrackerAsDict.serialise_tracker(tracker)
)

def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
"""Retrieve dialogues for a sender_id in reverse-chronological order.
Expand Down Expand Up @@ -499,7 +545,7 @@ def keys(self) -> Iterable[Text]:
return sender_ids


class MongoTrackerStore(TrackerStore):
class MongoTrackerStore(TrackerStore, SerializedTrackerAsText):
"""Stores conversation history in Mongo.
Property methods:
Expand Down Expand Up @@ -778,7 +824,7 @@ def validate_port(port: Any) -> Optional[int]:
return port


class SQLTrackerStore(TrackerStore):
class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
"""Store which can save and retrieve trackers from an SQL database."""

Base: DeclarativeMeta = declarative_base()
Expand Down Expand Up @@ -1110,8 +1156,10 @@ def _additional_events(


class FailSafeTrackerStore(TrackerStore):
"""Wraps a tracker store so that we can fallback to a different tracker store in
case of errors."""
"""Tracker store wrapper.
Allows a fallback to a different tracker store in case of errors.
"""

def __init__(
self,
Expand All @@ -1134,7 +1182,8 @@ def __init__(
super().__init__(tracker_store.domain, tracker_store.event_broker)

@property
def domain(self) -> Optional[Domain]:
def domain(self) -> Domain:
"""Returns the domain of the primary tracker store."""
return self._tracker_store.domain

@domain.setter
Expand Down
1 change: 1 addition & 0 deletions rasa/nlu/classifiers/logistic_regression_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def load(
model_storage: ModelStorage,
resource: Resource,
execution_context: ExecutionContext,
**kwargs: Any,
) -> GraphComponent:
"""Loads trained component (see parent class for full docstring)."""
try:
Expand Down
2 changes: 1 addition & 1 deletion rasa/nlu/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def retrieve(self, model_name: Text, target_path: Text) -> None:
self._copy(os.path.basename(tar_name), target_path)

@abc.abstractmethod
def _retrieve_tar(self, filename: Text) -> Text:
def _retrieve_tar(self, filename: Text) -> None:
"""Downloads a model previously persisted to cloud storage."""
raise NotImplementedError

Expand Down
4 changes: 3 additions & 1 deletion rasa/nlu/selectors/response_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ def label_sub_key(self) -> Text:
return LABEL_SUB_KEY

@staticmethod
def model_class(use_text_as_label: bool) -> Type[RasaModel]:
def model_class( # type: ignore[override]
use_text_as_label: bool,
) -> Type[RasaModel]:
"""Returns model class."""
if use_text_as_label:
return DIET2DIET
Expand Down
Loading

0 comments on commit 291f6c6

Please sign in to comment.