From e3ba870704ae4d69439ac3466f081b7827067311 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 26 Nov 2024 17:29:56 +0100 Subject: [PATCH] fix(product-assistant): trim redundant queries in the schema generator (#26425) --- ee/hogai/schema_generator/nodes.py | 9 ++- ee/hogai/schema_generator/test/test_nodes.py | 67 ++++++++++++++++++++ ee/hogai/taxonomy_agent/test/test_nodes.py | 31 +++++++++ ee/hogai/test/test_utils.py | 35 +++++++++- 4 files changed, 139 insertions(+), 3 deletions(-) diff --git a/ee/hogai/schema_generator/nodes.py b/ee/hogai/schema_generator/nodes.py index 560a9a7d5cc9e..c5e7ffbba85c4 100644 --- a/ee/hogai/schema_generator/nodes.py +++ b/ee/hogai/schema_generator/nodes.py @@ -144,7 +144,10 @@ def _construct_messages( human_messages, visualization_messages = filter_visualization_conversation(messages) first_ai_message = True - for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages): + for idx, (human_message, ai_message) in enumerate( + itertools.zip_longest(human_messages, visualization_messages) + ): + # Plans go first if ai_message: conversation.append( HumanMessagePromptTemplate.from_template( @@ -161,6 +164,7 @@ def _construct_messages( ).format(plan=generated_plan) ) + # Then questions if human_message: conversation.append( HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format( @@ -168,7 +172,8 @@ def _construct_messages( ) ) - if ai_message: + # Then schemas, but include only last generated schema because it doesn't need more context. + if ai_message and idx + 1 == len(visualization_messages): conversation.append( LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "") ) diff --git a/ee/hogai/schema_generator/test/test_nodes.py b/ee/hogai/schema_generator/test/test_nodes.py index af66234978794..795045af50b56 100644 --- a/ee/hogai/schema_generator/test/test_nodes.py +++ b/ee/hogai/schema_generator/test/test_nodes.py @@ -9,9 +9,11 @@ from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode from ee.hogai.schema_generator.utils import SchemaGeneratorOutput from posthog.schema import ( + AssistantMessage, AssistantTrendsQuery, FailureMessage, HumanMessage, + RouterMessage, VisualizationMessage, ) from posthog.test.base import APIBaseTest, ClickhouseTestMixin @@ -169,6 +171,71 @@ def test_agent_reconstructs_conversation_and_merges_messages(self): self.assertNotIn("{{question}}", history[5].content) self.assertIn("Follow\nUp", history[5].content) + def test_agent_reconstructs_typical_conversation(self): + node = DummyGeneratorNode(self.team) + history = node._construct_messages( + { + "messages": [ + HumanMessage(content="Question 1"), + RouterMessage(content="trends"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), + AssistantMessage(content="Summary 1"), + HumanMessage(content="Question 2"), + RouterMessage(content="funnel"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), + AssistantMessage(content="Summary 2"), + HumanMessage(content="Question 3"), + RouterMessage(content="funnel"), + ], + "plan": "Plan 3", + } + ) + self.assertEqual(len(history), 8) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("Plan 1", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Question 1", history[2].content) + self.assertEqual(history[3].type, "human") + self.assertIn("Plan 2", history[3].content) + self.assertEqual(history[4].type, "human") + self.assertIn("Question 2", history[4].content) + self.assertEqual(history[5].type, "ai") + self.assertEqual(history[6].type, "human") + self.assertIn("Plan 3", history[6].content) + self.assertEqual(history[7].type, "human") + self.assertIn("Question 3", history[7].content) + + def test_prompt(self): + node = DummyGeneratorNode(self.team) + state = { + "messages": [ + HumanMessage(content="Question 1"), + RouterMessage(content="trends"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), + AssistantMessage(content="Summary 1"), + HumanMessage(content="Question 2"), + RouterMessage(content="funnel"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), + AssistantMessage(content="Summary 2"), + HumanMessage(content="Question 3"), + RouterMessage(content="funnel"), + ], + "plan": "Plan 3", + } + with patch.object(DummyGeneratorNode, "_model") as generator_model_mock: + + def assert_prompt(prompt): + self.assertEqual(len(prompt), 4) + self.assertEqual(prompt[0].type, "system") + self.assertEqual(prompt[1].type, "human") + self.assertEqual(prompt[2].type, "ai") + self.assertEqual(prompt[3].type, "human") + + generator_model_mock.return_value = RunnableLambda(assert_prompt) + node.run(state, {}) + def test_failover_with_incorrect_schema(self): node = DummyGeneratorNode(self.team) with patch.object(DummyGeneratorNode, "_model") as generator_model_mock: diff --git a/ee/hogai/taxonomy_agent/test/test_nodes.py b/ee/hogai/taxonomy_agent/test/test_nodes.py index fe3d52266ec18..40127c19370b6 100644 --- a/ee/hogai/taxonomy_agent/test/test_nodes.py +++ b/ee/hogai/taxonomy_agent/test/test_nodes.py @@ -18,6 +18,7 @@ AssistantTrendsQuery, FailureMessage, HumanMessage, + RouterMessage, VisualizationMessage, ) from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @@ -116,6 +117,36 @@ def test_agent_reconstructs_conversation_with_failures(self): self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) + def test_agent_reconstructs_typical_conversation(self): + node = self._get_node() + history = node._construct_messages( + { + "messages": [ + HumanMessage(content="Question 1"), + RouterMessage(content="trends"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), + AssistantMessage(content="Summary 1"), + HumanMessage(content="Question 2"), + RouterMessage(content="funnel"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), + AssistantMessage(content="Summary 2"), + HumanMessage(content="Question 3"), + RouterMessage(content="funnel"), + ] + } + ) + self.assertEqual(len(history), 5) + self.assertEqual(history[0].type, "human") + self.assertIn("Question 1", history[0].content) + self.assertEqual(history[1].type, "ai") + self.assertEqual(history[1].content, "Plan 1") + self.assertEqual(history[2].type, "human") + self.assertIn("Question 2", history[2].content) + self.assertEqual(history[3].type, "ai") + self.assertEqual(history[3].content, "Plan 2") + self.assertEqual(history[4].type, "human") + self.assertIn("Question 3", history[4].content) + def test_agent_filters_out_low_count_events(self): _create_person(distinct_ids=["test"], team=self.team) for i in range(26): diff --git a/ee/hogai/test/test_utils.py b/ee/hogai/test/test_utils.py index 89f23d2fdd7b6..42e54d058c556 100644 --- a/ee/hogai/test/test_utils.py +++ b/ee/hogai/test/test_utils.py @@ -1,7 +1,14 @@ from langchain_core.messages import HumanMessage as LangchainHumanMessage from ee.hogai.utils import filter_visualization_conversation, merge_human_messages -from posthog.schema import AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage +from posthog.schema import ( + AssistantMessage, + AssistantTrendsQuery, + FailureMessage, + HumanMessage, + RouterMessage, + VisualizationMessage, +) from posthog.test.base import BaseTest @@ -37,3 +44,29 @@ def test_filter_trends_conversation(self): self.assertEqual( visualization_messages, [VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="plan")] ) + + def test_filters_typical_conversation(self): + human_messages, visualization_messages = filter_visualization_conversation( + [ + HumanMessage(content="Question 1"), + RouterMessage(content="trends"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), + AssistantMessage(content="Summary 1"), + HumanMessage(content="Question 2"), + RouterMessage(content="funnel"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), + AssistantMessage(content="Summary 2"), + ] + ) + self.assertEqual(len(human_messages), 2) + self.assertEqual(len(visualization_messages), 2) + self.assertEqual( + human_messages, [LangchainHumanMessage(content="Question 1"), LangchainHumanMessage(content="Question 2")] + ) + self.assertEqual( + visualization_messages, + [ + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), + ], + )