Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 committed Oct 28, 2024
1 parent 067db1f commit b8b9f38
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 50 deletions.
Empty file removed ee/hogai/test/__init__.py
Empty file.
48 changes: 0 additions & 48 deletions ee/hogai/test/test_assistant.py

This file was deleted.

2 changes: 1 addition & 1 deletion ee/hogai/trends/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,4 +403,4 @@ def run(self, state: AssistantState, config: RunnableConfig):
if not intermediate_steps:
return state
action, _ = intermediate_steps[-1]
return {"intermediate_steps": (action, action.log)}
return {"intermediate_steps": [(action, action.log)]}
73 changes: 72 additions & 1 deletion ee/hogai/trends/test/test_nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import json
from unittest.mock import patch

from django.test import override_settings
from langchain_core.agents import AgentAction
from langchain_core.runnables import RunnableLambda

from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode
from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode
from ee.hogai.trends.utils import GenerateTrendOutputModel
from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person

Expand Down Expand Up @@ -92,6 +98,25 @@ class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):
def setUp(self):
self.schema = ExperimentalAITrendsQuery(series=[])

def test_node_runs(self):
node = GenerateTrendsNode(self.team)
with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock:
generator_model_mock.return_value = RunnableLambda(
lambda _: GenerateTrendOutputModel(reasoning_steps=["step"], answer=self.schema).model_dump()
)
new_state = node.run(
{
"messages": [HumanMessage(content="Text")],
"plan": "Plan",
},
{},
)
self.assertNotIn("intermediate_steps", new_state)
self.assertEqual(
new_state["messages"],
[VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])],
)

def test_agent_reconstructs_conversation(self):
node = GenerateTrendsNode(self.team)
history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]})
Expand Down Expand Up @@ -203,3 +228,49 @@ def test_agent_reconstructs_conversation_and_merges_messages(self):
self.assertIn("Answer to this question:", history[5].content)
self.assertNotIn("{{question}}", history[5].content)
self.assertIn("Follow\nUp", history[5].content)

def test_failover_with_incorrect_schema(self):
node = GenerateTrendsNode(self.team)
with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock:
schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump()
# Emulate an incorrect JSON. It should be an object.
schema["answer"] = []
generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))

new_state = node.run({"messages": [HumanMessage(content="Text")]}, {})
self.assertIn("intermediate_steps", new_state)
self.assertEqual(len(new_state["intermediate_steps"]), 1)

def test_agent_reconstructs_conversation_with_failover(self):
action = AgentAction(tool="fix", tool_input="validation error", log="exception")
node = GenerateTrendsNode(self.team)
history = node._reconstruct_conversation(
{
"messages": [HumanMessage(content="Text")],
"plan": "randomplan",
"intermediate_steps": [(action, "uniqexception")],
},
"uniqexception",
)
self.assertEqual(len(history), 4)
self.assertEqual(history[0].type, "human")
self.assertIn("mapping", history[0].content)
self.assertEqual(history[1].type, "human")
self.assertIn("the plan", history[1].content)
self.assertNotIn("{{plan}}", history[1].content)
self.assertIn("randomplan", history[1].content)
self.assertEqual(history[2].type, "human")
self.assertIn("Answer to this question:", history[2].content)
self.assertNotIn("{{question}}", history[2].content)
self.assertIn("Text", history[2].content)
self.assertEqual(history[3].type, "human")
self.assertIn("Pydantic", history[3].content)
self.assertIn("uniqexception", history[3].content)


class TestGenerateTrendsToolsNode(ClickhouseTestMixin, APIBaseTest):
def test_tools_node(self):
node = GenerateTrendsToolsNode(self.team)
action = AgentAction(tool="fix", tool_input="validation error", log="exception")
state = node.run({"messages": [], "intermediate_steps": [(action, None)]}, {})
self.assertEqual(state, {"intermediate_steps": [(action, "exception")]})

0 comments on commit b8b9f38

Please sign in to comment.