Skip to content

Commit

Permalink
test(embeddings_only): parametrize config fixtures to add test for bo…
Browse files Browse the repository at this point in the history
…t colang 1 and 2
  • Loading branch information
Pouyanpi committed Sep 9, 2024
1 parent b5b04a5 commit 11ddfc7
Showing 1 changed file with 70 additions and 26 deletions.
96 changes: 70 additions & 26 deletions tests/test_embeddings_only_user_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.actions.llm.utils import LLMCallException
from nemoguardrails.llm.filters import colang
from tests.utils import TestChat


Expand Down Expand Up @@ -49,63 +50,93 @@ def config():
def colang_2_config():
return RailsConfig.from_content(
"""
import core
import llm
import core
import llm
flow main
activate greeting
activate llm continuation
flow main
activate greeting
activate llm continuation
flow user expressed greeting
user said "hi"
flow user expressed greeting
user said "hi"
or user said "hello"
flow bot express greeting
bot say "Hello!"
flow bot express greeting
bot say "Hello!"
flow greeting
user expressed greeting
bot express greeting
""",
flow greeting
user expressed greeting
bot express greeting
""",
"""
colang_version: 2.x
rails:
dialog:
user_messages:
embeddings_only: True
embeddings_only_similarity_threshold: 0.8
embeddings_only_fallback_intent: "user expressed greeting"
""",
colang_version: "2.x"
rails:
dialog:
user_messages:
embeddings_only: True
embeddings_only_similarity_threshold: 0.8
embeddings_only_fallback_intent: "user expressed greeting"
""",
)


@pytest.mark.parametrize("config", ["config", "colang_2_config"], indirect=True)
def test_greeting(config):
def test_greeting_1(config):
"""Test that the bot responds with 'Hello!' when the user says 'hello'."""

chat = TestChat(
config,
llm_completions=[],
)

chat >> "hello"
chat << "Hello!"


@pytest.mark.parametrize("config", ["config", "colang_2_config"], indirect=True)
def test_greeting_2(colang_2_config):
"""Test that the bot responds with 'Hello!' when the user says 'hello'."""

chat = TestChat(
colang_2_config,
llm_completions=[],
)

chat >> "hi"
chat << "Hello!"


def test_error_when_embeddings_only_is_false(config):
"""Test that an error is raised when the 'embeddings_only' option is False."""

# Check that if we deactivate the embeddings_only option we get an error
config.rails.dialog.user_messages.embeddings_only = False
chat = TestChat(
config,
llm_completions=[],
)

with pytest.raises(LLMCallException):
chat >> "how is your day?"
chat << "Hello!"


def test_error_when_embeddings_only_is_false_2(colang_2_config):
"""Test that an error is raised when the 'embeddings_only' option is False."""

# Check that if we deactivate the embeddings_only option we get an error
colang_2_config.rails.dialog.user_messages.embeddings_only = False
chat = TestChat(
colang_2_config,
llm_completions=[],
)

with pytest.raises(LLMCallException):
chat >> "hello"
chat >> "how is your day?"
chat << "Hello!"


@pytest.mark.parametrize("config", ["config", "colang_2_config"], indirect=True)
def test_fallback_intent(config):
"""Test that the bot uses the fallback intent when it doesn't recognize the user's message."""

rails = LLMRails(config)
res = rails.generate(messages=[{"role": "user", "content": "lets use fallback"}])
assert res["content"] == "Hello!"
Expand All @@ -114,3 +145,16 @@ def test_fallback_intent(config):
rails = LLMRails(config)
with pytest.raises(LLMCallException):
rails.generate(messages=[{"role": "user", "content": "lets use fallback"}])


def test_fallback_intent_2(colang_2_config):
"""Test that the bot uses the fallback intent when it doesn't recognize the user's message."""

rails = LLMRails(colang_2_config)
res = rails.generate(messages=[{"role": "user", "content": "lets use fallback"}])
assert res["content"] == "Hello!"

colang_2_config.rails.dialog.user_messages.embeddings_only_fallback_intent = None
rails = LLMRails(colang_2_config)
with pytest.raises(LLMCallException):
rails.generate(messages=[{"role": "user", "content": "lets use fallback"}])

0 comments on commit 11ddfc7

Please sign in to comment.