-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
149 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from langchain_core.messages import HumanMessage as LangchainHumanMessage | ||
|
||
from ee.hogai.trends import utils | ||
from posthog.schema import ExperimentalAITrendsQuery, FailureMessage, HumanMessage, VisualizationMessage | ||
from posthog.test.base import BaseTest | ||
|
||
|
||
class TestTrendsUtils(BaseTest): | ||
def test_merge_human_messages(self): | ||
res = utils.merge_human_messages( | ||
[ | ||
LangchainHumanMessage(content="Text"), | ||
LangchainHumanMessage(content="Text"), | ||
LangchainHumanMessage(content="Te"), | ||
LangchainHumanMessage(content="xt"), | ||
] | ||
) | ||
self.assertEqual(len(res), 1) | ||
self.assertEqual(res, [LangchainHumanMessage(content="Text\nTe\nxt")]) | ||
|
||
def test_filter_trends_conversation(self): | ||
human_messages, visualization_messages = utils.filter_trends_conversation( | ||
[ | ||
HumanMessage(content="Text"), | ||
FailureMessage(content="Error"), | ||
HumanMessage(content="Text"), | ||
VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan"), | ||
HumanMessage(content="Text2"), | ||
VisualizationMessage(answer=None, plan="plan"), | ||
] | ||
) | ||
self.assertEqual(len(human_messages), 2) | ||
self.assertEqual(len(visualization_messages), 1) | ||
self.assertEqual( | ||
human_messages, [LangchainHumanMessage(content="Text"), LangchainHumanMessage(content="Text2")] | ||
) | ||
self.assertEqual( | ||
visualization_messages, [VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan")] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,53 @@ | ||
from collections.abc import Sequence | ||
from typing import Optional | ||
|
||
from langchain_core.messages import HumanMessage as LangchainHumanMessage | ||
from langchain_core.messages import merge_message_runs | ||
from pydantic import BaseModel | ||
|
||
from posthog.schema import ExperimentalAITrendsQuery | ||
from ee.hogai.utils import AssistantMessageUnion | ||
from posthog.schema import ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage | ||
|
||
|
||
class GenerateTrendOutputModel(BaseModel): | ||
reasoning_steps: Optional[list[str]] = None | ||
answer: Optional[ExperimentalAITrendsQuery] = None | ||
|
||
|
||
def merge_human_messages(messages: list[LangchainHumanMessage]) -> list[LangchainHumanMessage]: | ||
""" | ||
Filters out duplicated human messages and merges them into one message. | ||
""" | ||
contents = set() | ||
filtered_messages = [] | ||
for message in messages: | ||
if message.content in contents: | ||
continue | ||
contents.add(message.content) | ||
filtered_messages.append(message) | ||
return merge_message_runs(filtered_messages) | ||
|
||
|
||
def filter_trends_conversation( | ||
messages: Sequence[AssistantMessageUnion], | ||
) -> tuple[list[LangchainHumanMessage], list[VisualizationMessage]]: | ||
""" | ||
Splits, filters and merges the message history to be consumable by agents. | ||
""" | ||
stack: list[LangchainHumanMessage] = [] | ||
human_messages: list[LangchainHumanMessage] = [] | ||
visualization_messages: list[VisualizationMessage] = [] | ||
|
||
for message in messages: | ||
if isinstance(message, HumanMessage): | ||
stack.append(LangchainHumanMessage(content=message.content)) | ||
elif isinstance(message, VisualizationMessage) and message.answer: | ||
if stack: | ||
human_messages += merge_human_messages(stack) | ||
stack = [] | ||
visualization_messages.append(message) | ||
|
||
if stack: | ||
human_messages += merge_human_messages(stack) | ||
|
||
return human_messages, visualization_messages |