Skip to content

Commit

Permalink
move custom actions used in tests to conftest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ancalita committed Jun 4, 2024
1 parent 87e5341 commit 609e2b8
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 137 deletions.
133 changes: 133 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from typing import List, Dict, Text, Any

from sanic import Sanic

import pytest

from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction
from rasa_sdk.events import SlotSet
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.types import DomainDict

Sanic.test_mode = True


Expand All @@ -14,3 +23,127 @@ def get_stack():
}
]
return dialogue_stack


class CustomAsyncAction(Action):
def name(cls) -> Text:
return "custom_async_action"

async def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("test", "foo"), SlotSet("test2", "boo")]


class CustomAction(Action):
def name(cls) -> Text:
return "custom_action"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("test", "bar")]


class CustomActionRaisingException(Action):
def name(cls) -> Text:
return "custom_action_exception"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
raise Exception("test exception")


class CustomActionWithDialogueStack(Action):
def name(cls) -> Text:
return "custom_action_with_dialogue_stack"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("stack", tracker.stack)]


class MockFormValidationAction(FormValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: str) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events

def name(self) -> str:
return "mock_form_validation_action"


class MockValidationAction(ValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def run(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
pass

def name(self) -> Text:
return "mock_validation_action"

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events


class SubclassTestActionA(Action):
def name(self):
return "subclass_test_action_a"


class SubclassTestActionB(SubclassTestActionA):
def name(self):
return "subclass_test_action_b"
Empty file removed tests/test_actions/__init__.py
Empty file.
132 changes: 0 additions & 132 deletions tests/test_actions/test_actions.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_server_list_actions_returns_200(sanic_app: Sanic):
assert len(response.json) == 9
print(response.json)
expected = [
# defined in tests/test_actions
# defined in tests/conftest.py
{"name": "custom_async_action"},
{"name": "custom_action"},
{"name": "custom_action_exception"},
Expand All @@ -46,7 +46,7 @@ def test_server_list_actions_returns_200(sanic_app: Sanic):
{"name": "mock_form_validation_action"},
# defined in tests/test_forms.py
{"name": "some_form"},
# defined in tests/test_actions
# defined in tests/conftest.py
{"name": "subclass_test_action_b"},
]
assert response.json == expected
Expand Down
2 changes: 1 addition & 1 deletion tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from tests.test_actions.test_actions import SubclassTestActionA, SubclassTestActionB
from tests.conftest import SubclassTestActionA, SubclassTestActionB

TEST_PACKAGE_BASE = "tests/executor_test_packages"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.test_actions.test_actions import MockFormValidationAction
from tests.conftest import MockFormValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.events import ActionExecuted, SlotSet
Expand Down
2 changes: 1 addition & 1 deletion tests/tracing/instrumentation/test_validation_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.test_actions.test_actions import (
from tests.conftest import (
MockValidationAction,
)
from rasa_sdk import Tracker
Expand Down

0 comments on commit 609e2b8

Please sign in to comment.