From 453ac8352b34a53799929036f3380348c138cd33 Mon Sep 17 00:00:00 2001 From: aleksandarmijat Date: Wed, 15 May 2024 17:00:55 +0200 Subject: [PATCH 1/6] Return 449 status code if the Domain is not stored or missing in the payload --- rasa_sdk/endpoint.py | 10 +++++++- rasa_sdk/executor.py | 55 +++++++++++++++++++++++++++++++++++++++--- rasa_sdk/interfaces.py | 11 +++++++++ 3 files changed, 72 insertions(+), 4 deletions(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index eb922c51a..456f475bf 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -31,7 +31,11 @@ from rasa_sdk.cli.arguments import add_endpoint_arguments from rasa_sdk.constants import DEFAULT_KEEP_ALIVE_TIMEOUT, DEFAULT_SERVER_PORT from rasa_sdk.executor import ActionExecutor - from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException + from rasa_sdk.interfaces import ( + ActionExecutionRejection, + ActionNotFoundException, + ActionMissingDomainException, + ) from rasa_sdk.plugin import plugin_manager from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes @@ -139,6 +143,10 @@ async def webhook(request: Request) -> HTTPResponse: logger.error(e) body = {"error": e.message, "action_name": e.action_name} return response.json(body, status=404) + except ActionMissingDomainException as e: + logger.error(e) + body = {"error": e.message, "action_name": e.action_name} + return response.json(body, status=449) set_span_attributes(span, action_call) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index 5d3e7a9f4..fb3972edf 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -10,7 +10,12 @@ import sys import os -from rasa_sdk.interfaces import Tracker, ActionNotFoundException, Action +from rasa_sdk.interfaces import ( + Tracker, + ActionNotFoundException, + Action, + ActionMissingDomainException, +) from rasa_sdk import utils @@ -24,7 +29,6 @@ class CollectingDispatcher: """Send messages back to user""" def __init__(self) -> None: - self.messages: List[Dict[Text, Any]] = [] def utter_message( @@ -162,6 +166,8 @@ def __init__(self) -> None: self.actions: Dict[Text, Callable] = {} self._modules: Dict[Text, TimestampModule] = {} self._loaded: Set[Type[Action]] = set() + self.domain: Dict[Text, Any] = {} + self.domain_digest: Text = "" def register_action(self, action: Union[Type[Action], Action]) -> None: if inspect.isclass(action): @@ -380,6 +386,49 @@ def validate_events(events: List[Dict[Text, Any]], action_name: Text): # we won't append this to validated events -> will be ignored return validated + def is_domain_digest_valid(self, domain_digest: Text) -> bool: + """Check if the domain_digest is valid + + If the domain_digest is empty or different from the one provided, it is invalid. + + Args: + domain_digest: latest value provided to compare the current value with. + + Returns: + True if the domain_digest is valid, False otherwise. + """ + return bool(self.domain_digest) and self.domain_digest == domain_digest + + def get_domain( + self, payload: Dict[Text, Any], action_name: Text + ) -> Dict[Text, Any]: + """Retrieve the Domain dictionary. + + This method returns the proper domain if present, otherwise raises the error. + + Args: + payload: request payload containing the Domain data. + action_name: name of the action that should be executed. + + Returns: + The Domain dictionary. + """ + payload_domain = payload.get("domain") + payload_domain_digest = payload.get("domain_digest") + + # If digest is invalid (empty or mismatched) and no domain is available - raise the error + if ( + not self.is_domain_digest_valid(payload_domain_digest) + and not payload_domain + ): + raise ActionMissingDomainException(action_name) + + if payload_domain: + self.domain = payload_domain + self.domain_digest = payload_domain_digest + + return self.domain + async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]: from rasa_sdk.interfaces import Tracker @@ -391,7 +440,7 @@ async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]: raise ActionNotFoundException(action_name) tracker_json = action_call["tracker"] - domain = action_call.get("domain", {}) + domain = self.get_domain(action_call, action_name) tracker = Tracker.from_dict(tracker_json) dispatcher = CollectingDispatcher() diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index 79b81c597..ae66022ad 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -384,3 +384,14 @@ def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: def __str__(self) -> Text: return self.message + + +class ActionMissingDomainException(Exception): + """Raising this exception when the domain is missing.""" + + def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: + self.action_name = action_name + self.message = message or f"Domain context is missing." + + def __str__(self) -> Text: + return self.message From ab28db741450295495bdcf69c2764587893e375d Mon Sep 17 00:00:00 2001 From: aleksandarmijat Date: Wed, 22 May 2024 18:45:47 +0200 Subject: [PATCH 2/6] Update naming and description of the method --- rasa_sdk/executor.py | 27 +++++++++++++++++---------- rasa_sdk/interfaces.py | 2 +- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index fb3972edf..942ade998 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -166,8 +166,8 @@ def __init__(self) -> None: self.actions: Dict[Text, Callable] = {} self._modules: Dict[Text, TimestampModule] = {} self._loaded: Set[Type[Action]] = set() - self.domain: Dict[Text, Any] = {} - self.domain_digest: Text = "" + self.domain: Dict[Text, Any] = None + self.domain_digest: Text = None def register_action(self, action: Union[Type[Action], Action]) -> None: if inspect.isclass(action): @@ -399,24 +399,31 @@ def is_domain_digest_valid(self, domain_digest: Text) -> bool: """ return bool(self.domain_digest) and self.domain_digest == domain_digest - def get_domain( + def update_and_return_domain( self, payload: Dict[Text, Any], action_name: Text ) -> Dict[Text, Any]: - """Retrieve the Domain dictionary. + """Validate the digest, store the domain if available, and return the domain. - This method returns the proper domain if present, otherwise raises the error. + This method validates the domain digest from the payload. + If the digest is invalid and no domain is provided, an exception is raised. + If domain data is available, it stores the domain and digest. + Finally, it returns the domain. Args: - payload: request payload containing the Domain data. - action_name: name of the action that should be executed. + payload: Request payload containing the domain data. + action_name: Name of the action that should be executed. Returns: - The Domain dictionary. + The domain dictionary. + + Raises: + ActionMissingDomainException: Invalid digest and no domain data available. + """ payload_domain = payload.get("domain") payload_domain_digest = payload.get("domain_digest") - # If digest is invalid (empty or mismatched) and no domain is available - raise the error + # If digest is invalid and no domain is available - raise the error if ( not self.is_domain_digest_valid(payload_domain_digest) and not payload_domain @@ -440,7 +447,7 @@ async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]: raise ActionNotFoundException(action_name) tracker_json = action_call["tracker"] - domain = self.get_domain(action_call, action_name) + domain = self.update_and_return_domain(action_call, action_name) tracker = Tracker.from_dict(tracker_json) dispatcher = CollectingDispatcher() diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index ae66022ad..454ad24b7 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -391,7 +391,7 @@ class ActionMissingDomainException(Exception): def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: self.action_name = action_name - self.message = message or f"Domain context is missing." + self.message = message or "Domain context is missing." def __str__(self) -> Text: return self.message From 4866d7c22d7aa82c3d7d6830380028f2a50605fe Mon Sep 17 00:00:00 2001 From: aleksandarmijat Date: Wed, 22 May 2024 19:08:44 +0200 Subject: [PATCH 3/6] Fix Code Quality check by updating the typing --- rasa_sdk/executor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index 942ade998..b87302efc 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -166,8 +166,8 @@ def __init__(self) -> None: self.actions: Dict[Text, Callable] = {} self._modules: Dict[Text, TimestampModule] = {} self._loaded: Set[Type[Action]] = set() - self.domain: Dict[Text, Any] = None - self.domain_digest: Text = None + self.domain: Optional[Dict[Text, Any]] = None + self.domain_digest: Optional[Text] = None def register_action(self, action: Union[Type[Action], Action]) -> None: if inspect.isclass(action): @@ -386,7 +386,7 @@ def validate_events(events: List[Dict[Text, Any]], action_name: Text): # we won't append this to validated events -> will be ignored return validated - def is_domain_digest_valid(self, domain_digest: Text) -> bool: + def is_domain_digest_valid(self, domain_digest: Optional[Text]) -> bool: """Check if the domain_digest is valid If the domain_digest is empty or different from the one provided, it is invalid. @@ -401,7 +401,7 @@ def is_domain_digest_valid(self, domain_digest: Text) -> bool: def update_and_return_domain( self, payload: Dict[Text, Any], action_name: Text - ) -> Dict[Text, Any]: + ) -> Optional[Dict[Text, Any]]: """Validate the digest, store the domain if available, and return the domain. This method validates the domain digest from the payload. @@ -436,7 +436,7 @@ def update_and_return_domain( return self.domain - async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]: + async def run(self, action_call: Dict[Text, Any]) -> Optional[Dict[Text, Any]]: from rasa_sdk.interfaces import Tracker action_name = action_call.get("next_action") From 552f0115139308f82a47f985c48d60fc251e1044 Mon Sep 17 00:00:00 2001 From: aleksandarmijat Date: Wed, 22 May 2024 19:10:16 +0200 Subject: [PATCH 4/6] Remove unused import --- rasa_sdk/executor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index b87302efc..f1e41ad8a 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -19,9 +19,6 @@ from rasa_sdk import utils -if typing.TYPE_CHECKING: # pragma: no cover - from rasa_sdk.types import ActionCall - logger = logging.getLogger(__name__) From 798a885e0b22731245b8469d831da8bbe1a2565e Mon Sep 17 00:00:00 2001 From: aleksandarmijat Date: Wed, 22 May 2024 19:12:25 +0200 Subject: [PATCH 5/6] Remove unused import --- rasa_sdk/executor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index f1e41ad8a..8698272ec 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -2,7 +2,6 @@ import inspect import logging import pkgutil -import typing import warnings from typing import Text, List, Dict, Any, Type, Union, Callable, Optional, Set, cast from collections import namedtuple From a4e9f596a23c7dee3ac31e9f267c144349c2e1ee Mon Sep 17 00:00:00 2001 From: aleksandarmijat Date: Wed, 22 May 2024 20:23:29 +0200 Subject: [PATCH 6/6] Fix tests --- rasa_sdk/executor.py | 2 +- tests/test_endpoint.py | 4 ++++ tests/tracing/instrumentation/test_tracing.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index 8698272ec..8ace7cfce 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -422,7 +422,7 @@ def update_and_return_domain( # If digest is invalid and no domain is available - raise the error if ( not self.is_domain_digest_valid(payload_domain_digest) - and not payload_domain + and payload_domain is None ): raise ActionMissingDomainException(action_name) diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 9f4090ca1..bfbe90a58 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -59,6 +59,7 @@ def test_server_webhook_handles_action_exception(): data = { "next_action": "custom_action_exception", "tracker": {"sender_id": "1", "conversation_id": "default"}, + "domain": {}, } request, response = app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 500 @@ -70,6 +71,7 @@ def test_server_webhook_custom_action_returns_200(): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, + "domain": {}, } request, response = app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") @@ -82,6 +84,7 @@ def test_server_webhook_custom_async_action_returns_200(): data = { "next_action": "custom_async_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, + "domain": {}, } request, response = app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") @@ -140,6 +143,7 @@ def test_server_webhook_custom_action_with_dialogue_stack_returns_200( data = { "next_action": "custom_action_with_dialogue_stack", "tracker": {"sender_id": "1", "conversation_id": "default", **stack_state}, + "domain": {}, } _, response = app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") diff --git a/tests/tracing/instrumentation/test_tracing.py b/tests/tracing/instrumentation/test_tracing.py index e1af9b46d..e045e99d3 100644 --- a/tests/tracing/instrumentation/test_tracing.py +++ b/tests/tracing/instrumentation/test_tracing.py @@ -39,6 +39,7 @@ def test_server_webhook_custom_action_is_instrumented( """Tests that the custom action is instrumented.""" data["next_action"] = action_name + data["domain"] = {} app = ep.create_app(action_package, tracer_provider=tracer_provider) _, response = app.test_client.post("/webhook", data=json.dumps(data))