Skip to content

Commit

Permalink
test: merging failures
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 committed Oct 29, 2024
1 parent 8b2497e commit a382bfa
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 37 deletions.
62 changes: 26 additions & 36 deletions ee/hogai/trends/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage as LangchainAssistantMessage
from langchain_core.messages import BaseMessage, merge_message_runs
from langchain_core.messages import HumanMessage as LangchainHumanMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
Expand All @@ -36,7 +35,7 @@
TrendsAgentToolkit,
TrendsAgentToolModel,
)
from ee.hogai.trends.utils import GenerateTrendOutputModel
from ee.hogai.trends.utils import GenerateTrendOutputModel, filter_trends_conversation
from ee.hogai.utils import (
AssistantNode,
AssistantNodeName,
Expand All @@ -49,7 +48,6 @@
from posthog.schema import (
CachedTeamTaxonomyQueryResponse,
FailureMessage,
HumanMessage,
TeamTaxonomyQuery,
VisualizationMessage,
)
Expand Down Expand Up @@ -177,26 +175,33 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:
"""
Reconstruct the conversation for the agent. On this step we only care about previously asked questions and generated plans. All other messages are filtered out.
"""
messages = state.get("messages", [])
if len(messages) == 0:
human_messages, visualization_messages = filter_trends_conversation(state.get("messages", []))

if not human_messages:
return []

conversation = [
HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format(
question=messages[0].content if isinstance(messages[0], HumanMessage) else ""
)
]
conversation = []

for message in messages[1:]:
if isinstance(message, HumanMessage):
conversation.append(
HumanMessagePromptTemplate.from_template(
react_follow_up_prompt,
template_format="mustache",
).format(feedback=message.content)
)
elif isinstance(message, VisualizationMessage):
conversation.append(LangchainAssistantMessage(content=message.plan or ""))
for idx, messages in enumerate(itertools.zip_longest(human_messages, visualization_messages)):
human_message, viz_message = messages

if human_message:
if idx == 0:
conversation.append(
HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format(
question=human_message.content
)
)
else:
conversation.append(
HumanMessagePromptTemplate.from_template(
react_follow_up_prompt,
template_format="mustache",
).format(feedback=human_message.content)
)

if viz_message:
conversation.append(LangchainAssistantMessage(content=viz_message.plan or ""))

return conversation

Expand Down Expand Up @@ -335,22 +340,7 @@ def _reconstruct_conversation(
)
]

stack: list[LangchainHumanMessage] = []
human_messages: list[LangchainHumanMessage] = []
visualization_messages: list[VisualizationMessage] = []

for message in messages:
if isinstance(message, HumanMessage):
stack.append(LangchainHumanMessage(content=message.content))
elif isinstance(message, VisualizationMessage) and message.answer:
if stack:
human_messages += merge_message_runs(stack)
stack = []
visualization_messages.append(message)

if stack:
human_messages += merge_message_runs(stack)

human_messages, visualization_messages = filter_trends_conversation(messages)
first_ai_message = True

for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages):
Expand Down
40 changes: 40 additions & 0 deletions ee/hogai/trends/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ def test_agent_reconstructs_conversation_and_omits_unknown_messages(self):
self.assertIn("Text", history[0].content)
self.assertNotIn("{{question}}", history[0].content)

def test_agent_reconstructs_conversation_with_failures(self):
    """Failure messages are dropped and duplicated human messages collapse into a single prompt."""
    node = CreateTrendsPlanNode(self.team)
    conversation = [
        HumanMessage(content="Text"),
        FailureMessage(content="Error"),
        HumanMessage(content="Text"),
    ]
    history = node._reconstruct_conversation({"messages": conversation})
    # Only one merged human prompt should survive, with the template fully rendered.
    self.assertEqual(len(history), 1)
    self.assertEqual(history[0].type, "human")
    self.assertIn("Text", history[0].content)
    self.assertNotIn("{{question}}", history[0].content)

def test_agent_filters_out_low_count_events(self):
_create_person(distinct_ids=["test"], team=self.team)
for i in range(26):
Expand Down Expand Up @@ -337,6 +353,30 @@ def test_agent_reconstructs_conversation_with_failover(self):
self.assertIn("Pydantic", history[3].content)
self.assertIn("uniqexception", history[3].content)

def test_agent_reconstructs_conversation_with_failed_messages(self):
    """Failure messages are filtered out while the plan and merged human question are kept."""
    node = GenerateTrendsNode(self.team)
    state = {
        "messages": [
            HumanMessage(content="Text"),
            FailureMessage(content="Error"),
            HumanMessage(content="Text"),
        ],
        "plan": "randomplan",
    }
    history = node._reconstruct_conversation(state)
    self.assertEqual(len(history), 3)
    # First message: the system/mapping prompt.
    self.assertEqual(history[0].type, "human")
    self.assertIn("mapping", history[0].content)
    # Second message: the plan prompt with the template variable rendered.
    self.assertEqual(history[1].type, "human")
    self.assertIn("the plan", history[1].content)
    self.assertNotIn("{{plan}}", history[1].content)
    self.assertIn("randomplan", history[1].content)
    # Third message: the deduplicated human question.
    self.assertEqual(history[2].type, "human")
    self.assertIn("Answer to this question:", history[2].content)
    self.assertNotIn("{{question}}", history[2].content)
    self.assertIn("Text", history[2].content)

def test_router(self):
node = GenerateTrendsNode(self.team)
state = node.router({"messages": [], "intermediate_steps": None})
Expand Down
39 changes: 39 additions & 0 deletions ee/hogai/trends/test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from langchain_core.messages import HumanMessage as LangchainHumanMessage

from ee.hogai.trends import utils
from posthog.schema import ExperimentalAITrendsQuery, FailureMessage, HumanMessage, VisualizationMessage
from posthog.test.base import BaseTest


class TestTrendsUtils(BaseTest):
    def test_merge_human_messages(self):
        """Duplicate contents are removed and the remaining messages merge into one run."""
        messages = [
            LangchainHumanMessage(content="Text"),
            LangchainHumanMessage(content="Text"),
            LangchainHumanMessage(content="Te"),
            LangchainHumanMessage(content="xt"),
        ]
        merged = utils.merge_human_messages(messages)
        self.assertEqual(merged, [LangchainHumanMessage(content="Text\nTe\nxt")])
        self.assertEqual(len(merged), 1)

    def test_filter_trends_conversation(self):
        """Failures and answerless visualizations are dropped; human messages are deduplicated."""
        conversation = [
            HumanMessage(content="Text"),
            FailureMessage(content="Error"),
            HumanMessage(content="Text"),
            VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan"),
            HumanMessage(content="Text2"),
            VisualizationMessage(answer=None, plan="plan"),
        ]
        humans, visualizations = utils.filter_trends_conversation(conversation)
        self.assertEqual(
            humans, [LangchainHumanMessage(content="Text"), LangchainHumanMessage(content="Text2")]
        )
        self.assertEqual(
            visualizations, [VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan")]
        )
        self.assertEqual(len(humans), 2)
        self.assertEqual(len(visualizations), 1)
45 changes: 44 additions & 1 deletion ee/hogai/trends/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,53 @@
from collections.abc import Sequence
from typing import Optional

from langchain_core.messages import HumanMessage as LangchainHumanMessage
from langchain_core.messages import merge_message_runs
from pydantic import BaseModel

from posthog.schema import ExperimentalAITrendsQuery
from ee.hogai.utils import AssistantMessageUnion
from posthog.schema import ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage


# Structured output schema for the trends-generation LLM call. Intentionally
# commented with `#` rather than a docstring: pydantic would surface a class
# docstring as the JSON-schema description sent to the model.
class GenerateTrendOutputModel(BaseModel):
    # Optional chain-of-thought style steps the model may emit alongside the answer.
    reasoning_steps: Optional[list[str]] = None
    # The generated trends query; None signals a failed/empty generation.
    answer: Optional[ExperimentalAITrendsQuery] = None


def merge_human_messages(messages: list[LangchainHumanMessage]) -> list[LangchainHumanMessage]:
    """
    Drop human messages whose content was already seen, then merge the
    remaining consecutive messages into a single message run.
    """
    seen_contents: set = set()
    deduped: list[LangchainHumanMessage] = []
    for msg in messages:
        # Keep only the first occurrence of each content string.
        if msg.content not in seen_contents:
            seen_contents.add(msg.content)
            deduped.append(msg)
    return merge_message_runs(deduped)


def filter_trends_conversation(
    messages: Sequence[AssistantMessageUnion],
) -> tuple[list[LangchainHumanMessage], list[VisualizationMessage]]:
    """
    Splits, filters and merges the message history to be consumable by agents.

    Human messages accumulated since the previous successful visualization are
    merged into one entry; visualization messages without an answer (failed
    generations) and all other message types are discarded.
    """
    human_messages: list[LangchainHumanMessage] = []
    visualization_messages: list[VisualizationMessage] = []
    pending: list[LangchainHumanMessage] = []

    def _flush_pending() -> None:
        # Merge and deduplicate the buffered human messages, then reset the buffer.
        nonlocal pending
        if pending:
            human_messages.extend(merge_human_messages(pending))
            pending = []

    for message in messages:
        if isinstance(message, HumanMessage):
            pending.append(LangchainHumanMessage(content=message.content))
        elif isinstance(message, VisualizationMessage) and message.answer:
            _flush_pending()
            visualization_messages.append(message)

    _flush_pending()

    return human_messages, visualization_messages

0 comments on commit a382bfa

Please sign in to comment.