Skip to content

Commit

Permalink
fix(product-assistant): trim redundant queries in the schema generator
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 authored Nov 26, 2024
1 parent 9b819fc commit e3ba870
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 3 deletions.
9 changes: 7 additions & 2 deletions ee/hogai/schema_generator/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ def _construct_messages(
human_messages, visualization_messages = filter_visualization_conversation(messages)
first_ai_message = True

for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages):
for idx, (human_message, ai_message) in enumerate(
itertools.zip_longest(human_messages, visualization_messages)
):
# Plans go first
if ai_message:
conversation.append(
HumanMessagePromptTemplate.from_template(
Expand All @@ -161,14 +164,16 @@ def _construct_messages(
).format(plan=generated_plan)
)

# Then questions
if human_message:
conversation.append(
HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format(
question=human_message.content
)
)

if ai_message:
# Then schemas, but include only last generated schema because it doesn't need more context.
if ai_message and idx + 1 == len(visualization_messages):
conversation.append(
LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "")
)
Expand Down
67 changes: 67 additions & 0 deletions ee/hogai/schema_generator/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode
from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
from posthog.schema import (
AssistantMessage,
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
RouterMessage,
VisualizationMessage,
)
from posthog.test.base import APIBaseTest, ClickhouseTestMixin
Expand Down Expand Up @@ -169,6 +171,71 @@ def test_agent_reconstructs_conversation_and_merges_messages(self):
self.assertNotIn("{{question}}", history[5].content)
self.assertIn("Follow\nUp", history[5].content)

def test_agent_reconstructs_typical_conversation(self):
    """A full multi-turn conversation should collapse into plan/question pairs,
    with only the most recent generated schema kept as an AI message."""
    node = DummyGeneratorNode(self.team)
    conversation = [
        HumanMessage(content="Question 1"),
        RouterMessage(content="trends"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
        AssistantMessage(content="Summary 1"),
        HumanMessage(content="Question 2"),
        RouterMessage(content="funnel"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
        AssistantMessage(content="Summary 2"),
        HumanMessage(content="Question 3"),
        RouterMessage(content="funnel"),
    ]
    history = node._construct_messages({"messages": conversation, "plan": "Plan 3"})

    self.assertEqual(len(history), 8)
    # (message type, substring expected in content); None means type-only check.
    expected = [
        ("human", "mapping"),
        ("human", "Plan 1"),
        ("human", "Question 1"),
        ("human", "Plan 2"),
        ("human", "Question 2"),
        ("ai", None),
        ("human", "Plan 3"),
        ("human", "Question 3"),
    ]
    for index, (message_type, fragment) in enumerate(expected):
        self.assertEqual(history[index].type, message_type)
        if fragment is not None:
            self.assertIn(fragment, history[index].content)

def test_prompt(self):
    """The prompt handed to the model should be fully merged: system message,
    combined plan/question context, the last schema, and the final question."""
    node = DummyGeneratorNode(self.team)
    state = {
        "messages": [
            HumanMessage(content="Question 1"),
            RouterMessage(content="trends"),
            VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
            AssistantMessage(content="Summary 1"),
            HumanMessage(content="Question 2"),
            RouterMessage(content="funnel"),
            VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
            AssistantMessage(content="Summary 2"),
            HumanMessage(content="Question 3"),
            RouterMessage(content="funnel"),
        ],
        "plan": "Plan 3",
    }
    with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:

        def assert_prompt(prompt):
            expected_types = ["system", "human", "ai", "human"]
            self.assertEqual(len(prompt), len(expected_types))
            for message, expected_type in zip(prompt, expected_types):
                self.assertEqual(message.type, expected_type)

        generator_model_mock.return_value = RunnableLambda(assert_prompt)
        node.run(state, {})

def test_failover_with_incorrect_schema(self):
node = DummyGeneratorNode(self.team)
with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
Expand Down
31 changes: 31 additions & 0 deletions ee/hogai/taxonomy_agent/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
RouterMessage,
VisualizationMessage,
)
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person
Expand Down Expand Up @@ -116,6 +117,36 @@ def test_agent_reconstructs_conversation_with_failures(self):
self.assertIn("Text", history[0].content)
self.assertNotIn("{{question}}", history[0].content)

def test_agent_reconstructs_typical_conversation(self):
    """Router and summary messages are dropped; the remaining history alternates
    between human questions and AI plans, ending with the unanswered question."""
    node = self._get_node()
    history = node._construct_messages(
        {
            "messages": [
                HumanMessage(content="Question 1"),
                RouterMessage(content="trends"),
                VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
                AssistantMessage(content="Summary 1"),
                HumanMessage(content="Question 2"),
                RouterMessage(content="funnel"),
                VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
                AssistantMessage(content="Summary 2"),
                HumanMessage(content="Question 3"),
                RouterMessage(content="funnel"),
            ]
        }
    )
    self.assertEqual(len(history), 5)
    # (type, text, exact): exact=True compares full content, otherwise substring.
    expectations = [
        ("human", "Question 1", False),
        ("ai", "Plan 1", True),
        ("human", "Question 2", False),
        ("ai", "Plan 2", True),
        ("human", "Question 3", False),
    ]
    for index, (message_type, text, exact) in enumerate(expectations):
        self.assertEqual(history[index].type, message_type)
        if exact:
            self.assertEqual(history[index].content, text)
        else:
            self.assertIn(text, history[index].content)

def test_agent_filters_out_low_count_events(self):
_create_person(distinct_ids=["test"], team=self.team)
for i in range(26):
Expand Down
35 changes: 34 additions & 1 deletion ee/hogai/test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from langchain_core.messages import HumanMessage as LangchainHumanMessage

from ee.hogai.utils import filter_visualization_conversation, merge_human_messages
from posthog.schema import AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage
from posthog.schema import (
AssistantMessage,
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
RouterMessage,
VisualizationMessage,
)
from posthog.test.base import BaseTest


Expand Down Expand Up @@ -37,3 +44,29 @@ def test_filter_trends_conversation(self):
self.assertEqual(
visualization_messages, [VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="plan")]
)

def test_filters_typical_conversation(self):
    """Only human and visualization messages survive filtering; router and
    assistant summaries are discarded, order preserved."""
    conversation = [
        HumanMessage(content="Question 1"),
        RouterMessage(content="trends"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
        AssistantMessage(content="Summary 1"),
        HumanMessage(content="Question 2"),
        RouterMessage(content="funnel"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
        AssistantMessage(content="Summary 2"),
    ]
    human_messages, visualization_messages = filter_visualization_conversation(conversation)

    self.assertEqual(len(human_messages), 2)
    self.assertEqual(len(visualization_messages), 2)
    expected_humans = [
        LangchainHumanMessage(content="Question 1"),
        LangchainHumanMessage(content="Question 2"),
    ]
    self.assertEqual(human_messages, expected_humans)
    expected_visualizations = [
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
    ]
    self.assertEqual(visualization_messages, expected_visualizations)

0 comments on commit e3ba870

Please sign in to comment.