Skip to content

Commit

Permalink
Add support for arbitrary input events for TestChat and make standard…
Browse files Browse the repository at this point in the history
… user message UMIM compliant.
  • Loading branch information
sklinglernv committed Oct 22, 2024
1 parent 3a9897b commit 31f1ae8
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
33 changes: 23 additions & 10 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import asyncio
import json
import sys
from typing import Any, Dict, Iterable, List, Mapping, Optional
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
Expand All @@ -32,7 +32,7 @@
create_flow_configs_from_flow_list,
)
from nemoguardrails.colang.v2_x.runtime.statemachine import initialize_state
from nemoguardrails.utils import EnhancedJsonEncoder, new_event_dict
from nemoguardrails.utils import EnhancedJsonEncoder, new_event_dict, new_uuid


class FakeLLM(LLM):
Expand Down Expand Up @@ -157,16 +157,29 @@ def __init__(
self.state,
)

def user(self, msg: str):
def user(self, msg: Union[str, dict]):
if self.config.colang_version == "1.0":
self.history.append({"role": "user", "content": msg})
elif self.config.colang_version == "2.x":
self.input_events.append(
{
"type": "UtteranceUserActionFinished",
"final_transcript": msg,
}
)
if isinstance(msg, str):
uid = new_uuid()
self.input_events.extend(
[
new_event_dict("UtteranceUserActionStarted", action_uid=uid),
new_event_dict(
"UtteranceUserActionFinished",
final_transcript=msg,
action_uid=uid,
is_success=True,
),
]
)
elif "type" in msg:
self.input_events.append(msg)
else:
raise ValueError(
f"Invalid user message: {msg}. Must be either str or event"
)
else:
raise Exception(f"Invalid colang version: {self.config.colang_version}")

Expand Down Expand Up @@ -223,7 +236,7 @@ async def bot_async(self, msg: str):
), f"Expected `{msg}` and received `{result['content']}`"
self.history.append(result)

def __rshift__(self, msg: str):
def __rshift__(self, msg: Union[str, dict]):
self.user(msg)

def __lshift__(self, msg: str):
Expand Down
47 changes: 47 additions & 0 deletions tests/v2_x/test_event_mechanics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1667,5 +1667,52 @@ def test_runtime_exception_handling_2():
)


def test_user_message_generates_started_and_finished():
"""Test queuing of action events."""
config = RailsConfig.from_content(
colang_content="""
flow main
match UtteranceUserActionStarted()
match UtteranceUserActionFinished(final_transcript="yes")
start UtteranceBotAction(script="ok")
""",
yaml_content="""
colang_version: "2.x"
""",
)

chat = TestChat(
config,
llm_completions=[],
)

chat >> "yes"
chat << "ok"


def test_handling_arbitrary_events_through_test_chat():
"""Test queuing of action events."""
config = RailsConfig.from_content(
colang_content="""
flow main
match CustomEvent(name="test")
match EventA()
start UtteranceBotAction(script="started")
""",
yaml_content="""
colang_version: "2.x"
""",
)

chat = TestChat(
config,
llm_completions=[],
)

chat >> {"type": "CustomEvent", "name": "test"}
chat >> {"type": "EventA"}
chat << "started"


if __name__ == "__main__":
test_event_match_group()

0 comments on commit 31f1ae8

Please sign in to comment.