diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index e72af9b86290d..206b74173a192 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -9,7 +9,6 @@ from langchain_core.exceptions import OutputParserException from langchain_core.messages import AIMessage as LangchainAssistantMessage from langchain_core.messages import BaseMessage, merge_message_runs -from langchain_core.messages import HumanMessage as LangchainHumanMessage from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI @@ -36,7 +35,7 @@ TrendsAgentToolkit, TrendsAgentToolModel, ) -from ee.hogai.trends.utils import GenerateTrendOutputModel +from ee.hogai.trends.utils import GenerateTrendOutputModel, filter_trends_conversation from ee.hogai.utils import ( AssistantNode, AssistantNodeName, @@ -49,7 +48,6 @@ from posthog.schema import ( CachedTeamTaxonomyQueryResponse, FailureMessage, - HumanMessage, TeamTaxonomyQuery, VisualizationMessage, ) @@ -177,26 +175,33 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: """ Reconstruct the conversation for the agent. On this step we only care about previously asked questions and generated plans. All other messages are filtered out. """ - messages = state.get("messages", []) - if len(messages) == 0: + human_messages, visualization_messages = filter_trends_conversation(state.get("messages", [])) + + if not human_messages: return [] - conversation = [ - HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format( - question=messages[0].content if isinstance(messages[0], HumanMessage) else "" - ) - ] + conversation = [] - for message in messages[1:]: - if isinstance(message, HumanMessage): - conversation.append( - HumanMessagePromptTemplate.from_template( - react_follow_up_prompt, - template_format="mustache", - ).format(feedback=message.content) - ) - elif isinstance(message, VisualizationMessage): - conversation.append(LangchainAssistantMessage(content=message.plan or "")) + for idx, messages in enumerate(itertools.zip_longest(human_messages, visualization_messages)): + human_message, viz_message = messages + + if human_message: + if idx == 0: + conversation.append( + HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format( + question=human_message.content + ) + ) + else: + conversation.append( + HumanMessagePromptTemplate.from_template( + react_follow_up_prompt, + template_format="mustache", + ).format(feedback=human_message.content) + ) + + if viz_message: + conversation.append(LangchainAssistantMessage(content=viz_message.plan or "")) return conversation @@ -335,22 +340,7 @@ def _reconstruct_conversation( ) ] - stack: list[LangchainHumanMessage] = [] - human_messages: list[LangchainHumanMessage] = [] - visualization_messages: list[VisualizationMessage] = [] - - for message in messages: - if isinstance(message, HumanMessage): - stack.append(LangchainHumanMessage(content=message.content)) - elif isinstance(message, VisualizationMessage) and message.answer: - if stack: - human_messages += merge_message_runs(stack) - stack = [] - visualization_messages.append(message) - - if stack: - human_messages += merge_message_runs(stack) - + human_messages, visualization_messages = filter_trends_conversation(messages) first_ai_message = True for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages): diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 57f2ae8358fe8..990f83ddca754 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -81,6 +81,22 @@ def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) + def test_agent_reconstructs_conversation_with_failures(self): + node = CreateTrendsPlanNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + ] + } + ) + self.assertEqual(len(history), 1) + self.assertEqual(history[0].type, "human") + self.assertIn("Text", history[0].content) + self.assertNotIn("{{question}}", history[0].content) + def test_agent_filters_out_low_count_events(self): _create_person(distinct_ids=["test"], team=self.team) for i in range(26): @@ -337,6 +353,30 @@ def test_agent_reconstructs_conversation_with_failover(self): self.assertIn("Pydantic", history[3].content) self.assertIn("uniqexception", history[3].content) + def test_agent_reconstructs_conversation_with_failed_messages(self): + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + ], + "plan": "randomplan", + }, + ) + self.assertEqual(len(history), 3) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + def test_router(self): node = GenerateTrendsNode(self.team) state = node.router({"messages": [], "intermediate_steps": None}) diff --git a/ee/hogai/trends/test/test_utils.py b/ee/hogai/trends/test/test_utils.py new file mode 100644 index 0000000000000..de9b8733129ec --- /dev/null +++ b/ee/hogai/trends/test/test_utils.py @@ -0,0 +1,39 @@ +from langchain_core.messages import HumanMessage as LangchainHumanMessage + +from ee.hogai.trends import utils +from posthog.schema import ExperimentalAITrendsQuery, FailureMessage, HumanMessage, VisualizationMessage +from posthog.test.base import BaseTest + + +class TestTrendsUtils(BaseTest): + def test_merge_human_messages(self): + res = utils.merge_human_messages( + [ + LangchainHumanMessage(content="Text"), + LangchainHumanMessage(content="Text"), + LangchainHumanMessage(content="Te"), + LangchainHumanMessage(content="xt"), + ] + ) + self.assertEqual(len(res), 1) + self.assertEqual(res, [LangchainHumanMessage(content="Text\nTe\nxt")]) + + def test_filter_trends_conversation(self): + human_messages, visualization_messages = utils.filter_trends_conversation( + [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan"), + HumanMessage(content="Text2"), + VisualizationMessage(answer=None, plan="plan"), + ] + ) + self.assertEqual(len(human_messages), 2) + self.assertEqual(len(visualization_messages), 1) + self.assertEqual( + human_messages, [LangchainHumanMessage(content="Text"), LangchainHumanMessage(content="Text2")] + ) + self.assertEqual( + visualization_messages, [VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan")] + ) diff --git a/ee/hogai/trends/utils.py b/ee/hogai/trends/utils.py index 080f85f0256d0..5e1a8052707c8 100644 --- a/ee/hogai/trends/utils.py +++ b/ee/hogai/trends/utils.py @@ -1,10 +1,53 @@ +from collections.abc import Sequence from typing import Optional +from langchain_core.messages import HumanMessage as LangchainHumanMessage +from langchain_core.messages import merge_message_runs from pydantic import BaseModel -from posthog.schema import ExperimentalAITrendsQuery +from ee.hogai.utils import AssistantMessageUnion +from posthog.schema import ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage class GenerateTrendOutputModel(BaseModel): reasoning_steps: Optional[list[str]] = None answer: Optional[ExperimentalAITrendsQuery] = None + + +def merge_human_messages(messages: list[LangchainHumanMessage]) -> list[LangchainHumanMessage]: + """ + Filters out duplicated human messages and merges them into one message. + """ + contents = set() + filtered_messages = [] + for message in messages: + if message.content in contents: + continue + contents.add(message.content) + filtered_messages.append(message) + return merge_message_runs(filtered_messages) + + +def filter_trends_conversation( + messages: Sequence[AssistantMessageUnion], +) -> tuple[list[LangchainHumanMessage], list[VisualizationMessage]]: + """ + Splits, filters and merges the message history to be consumable by agents. + """ + stack: list[LangchainHumanMessage] = [] + human_messages: list[LangchainHumanMessage] = [] + visualization_messages: list[VisualizationMessage] = [] + + for message in messages: + if isinstance(message, HumanMessage): + stack.append(LangchainHumanMessage(content=message.content)) + elif isinstance(message, VisualizationMessage) and message.answer: + if stack: + human_messages += merge_human_messages(stack) + stack = [] + visualization_messages.append(message) + + if stack: + human_messages += merge_human_messages(stack) + + return human_messages, visualization_messages