feat(product-assistant): trends generation failover #25769

Merged · 28 commits · Oct 30, 2024
Changes from 25 commits

Commits (28)
89a1579  feat: fail over (skoob13, Oct 25, 2024)
680ea43  fix: fallback streaming (skoob13, Oct 23, 2024)
c19028b  test: assistant test (skoob13, Oct 25, 2024)
88409ee  Update query snapshots (github-actions[bot], Oct 25, 2024)
30a5cc6  Update query snapshots (github-actions[bot], Oct 25, 2024)
d3fb521  Merge branch 'feat/trends-generation-failover' of github.com:PostHog/… (skoob13, Oct 28, 2024)
e360d47  Merge branch 'master' of github.com:PostHog/posthog into feat/trends-… (skoob13, Oct 28, 2024)
067db1f  fix: test (skoob13, Oct 28, 2024)
b8b9f38  fix: tests (skoob13, Oct 28, 2024)
800e64b  test: more tests (skoob13, Oct 28, 2024)
c2efd71  Merge branch 'master' of github.com:PostHog/posthog into feat/trends-… (skoob13, Oct 29, 2024)
2ceb484  feat: improved validation message (skoob13, Oct 29, 2024)
923768b  feat: status messages (skoob13, Oct 29, 2024)
ad7e07c  test: failover (skoob13, Oct 29, 2024)
40a1e1c  feat: frontend messages for repeated generations (skoob13, Oct 29, 2024)
8b2497e  feat: retry generation (skoob13, Oct 29, 2024)
a382bfa  test: merging failures (skoob13, Oct 29, 2024)
d162c45  fix: padding for the thread (skoob13, Oct 29, 2024)
71a914d  fix: keep message status (skoob13, Oct 29, 2024)
fd157fb  Update UI snapshots for `chromium` (1) (github-actions[bot], Oct 29, 2024)
0657e1a  feat: handle network/parsing errors (skoob13, Oct 29, 2024)
86b472a  Merge branch 'feat/trends-generation-failover' of github.com:PostHog/… (skoob13, Oct 29, 2024)
9862fcb  Merge branch 'master' of github.com:PostHog/posthog into feat/trends-… (skoob13, Oct 29, 2024)
a14457a  Update query snapshots (github-actions[bot], Oct 29, 2024)
56cb58d  Update query snapshots (github-actions[bot], Oct 29, 2024)
4bfde89  chore: code style (skoob13, Oct 30, 2024)
197c0f0  Merge branch 'master' of github.com:PostHog/posthog into feat/trends-… (skoob13, Oct 30, 2024)
f4bbbb2  Update query snapshots (github-actions[bot], Oct 30, 2024)
59 changes: 45 additions & 14 deletions ee/hogai/assistant.py
@@ -4,12 +4,18 @@
 from langchain_core.messages import AIMessageChunk
 from langfuse.callback import CallbackHandler
 from langgraph.graph.state import StateGraph
+from pydantic import BaseModel
 
 from ee import settings
-from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
+from ee.hogai.trends.nodes import (
+    CreateTrendsPlanNode,
+    CreateTrendsPlanToolsNode,
+    GenerateTrendsNode,
+    GenerateTrendsToolsNode,
+)
 from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
 from posthog.models.team.team import Team
-from posthog.schema import VisualizationMessage
+from posthog.schema import AssistantGenerationStatusEvent, AssistantGenerationStatusType, VisualizationMessage
 
 if settings.LANGFUSE_PUBLIC_KEY:
     langfuse_handler = CallbackHandler(
@@ -39,6 +45,13 @@ def is_message_update(
     return len(update) == 2 and update[0] == "messages"
 
 
+def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]:
+    """
+    Update of the state.
+    """
+    return len(update) == 2 and update[0] == "values"
+
+
 class Assistant:
     _team: Team
     _graph: StateGraph
@@ -59,38 +72,56 @@ def _compile_graph(self):
         generate_trends_node = GenerateTrendsNode(self._team)
         builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)
 
+        generate_trends_tools_node = GenerateTrendsToolsNode(self._team)
+        builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run)
+        builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name)
+
         builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
         builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
         builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
         builder.add_conditional_edges(GenerateTrendsNode.name, generate_trends_node.router)
 
         return builder.compile()
 
-    def stream(self, conversation: Conversation) -> Generator[str, None, None]:
+    def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
         assistant_graph = self._compile_graph()
         callbacks = [langfuse_handler] if langfuse_handler else []
         messages = [message.root for message in conversation.messages]
 
-        chunks = AIMessageChunk(content="")
+        state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}
 
         generator = assistant_graph.stream(
-            {"messages": messages},
+            state,
             config={"recursion_limit": 24, "callbacks": callbacks},
-            stream_mode=["messages", "updates"],
+            stream_mode=["messages", "values", "updates"],
         )
 
+        chunks = AIMessageChunk(content="")
+
         # Send a chunk to establish the connection avoiding the worker's timeout.
-        yield ""
+        yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
 
         for update in generator:
-            if is_value_update(update):
+            if is_state_update(update):
+                _, new_state = update
+                state = new_state
+
+            elif is_value_update(update):
                 _, state_update = update
-                if (
-                    AssistantNodeName.GENERATE_TRENDS in state_update
-                    and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]
-                ):
-                    message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0])
-                    yield message.model_dump_json()
+
+                if AssistantNodeName.GENERATE_TRENDS in state_update:
+                    # Reset chunks when schema validation fails.
+                    chunks = AIMessageChunk(content="")
+
+                    if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]:
+                        message = cast(
+                            VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]
+                        )
+                        yield message
+                    elif state_update[AssistantNodeName.GENERATE_TRENDS].get("intermediate_steps", []):
+                        yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
 
             elif is_message_update(update):
                 langchain_message, langgraph_state = update[1]
                 if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance(
@@ -101,4 +132,4 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]:
                     if parsed_message:
                         yield VisualizationMessage(
                             reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
-                        ).model_dump_json()
+                        )
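
Note: stream() now yields typed Pydantic models (AssistantGenerationStatusEvent for the ACK and generation-error signals, VisualizationMessage for results) instead of pre-serialized JSON strings, so serialization moves to the caller. A minimal sketch of the consuming side, assuming an SSE response; the as_server_sent_events helper and its framing are illustrative assumptions, not part of this diff:

# Sketch: serializing the assistant's typed event stream for an SSE response.
# The helper name and the SSE framing are assumptions, not code from this PR.
from collections.abc import Generator

from pydantic import BaseModel


def as_server_sent_events(events: Generator[BaseModel, None, None]) -> Generator[str, None, None]:
    for event in events:
        # Every yielded object is a Pydantic model, so a single serialization
        # path covers status events and visualization messages alike.
        yield f"data: {event.model_dump_json()}\n\n"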
143 changes: 87 additions & 56 deletions ee/hogai/trends/nodes.py
@@ -1,29 +1,29 @@
 import itertools
-import json
 import xml.etree.ElementTree as ET
 from functools import cached_property
-from typing import Union, cast
+from typing import Optional, Union, cast
 
 from langchain.agents.format_scratchpad import format_log_to_str
 from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.exceptions import OutputParserException
 from langchain_core.messages import AIMessage as LangchainAssistantMessage
 from langchain_core.messages import BaseMessage, merge_message_runs
-from langchain_core.messages import HumanMessage as LangchainHumanMessage
-from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
-from langchain_core.runnables import RunnableConfig, RunnableLambda
+from langchain_core.runnables import RunnableConfig
 from langchain_openai import ChatOpenAI
 from pydantic import ValidationError
 
 from ee.hogai.hardcoded_definitions import hardcoded_prop_defs
+from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output
 from ee.hogai.trends.prompts import (
     react_definitions_prompt,
     react_follow_up_prompt,
     react_scratchpad_prompt,
     react_system_prompt,
     react_user_prompt,
+    trends_failover_output_prompt,
+    trends_failover_prompt,
     trends_group_mapping_prompt,
     trends_new_plan_prompt,
     trends_plan_prompt,
@@ -35,7 +35,7 @@
     TrendsAgentToolkit,
     TrendsAgentToolModel,
 )
-from ee.hogai.trends.utils import GenerateTrendOutputModel
+from ee.hogai.trends.utils import GenerateTrendOutputModel, filter_trends_conversation
 from ee.hogai.utils import (
     AssistantNode,
     AssistantNodeName,
@@ -45,7 +45,12 @@
 from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner
 from posthog.hogql_queries.query_runner import ExecutionMode
 from posthog.models.group_type_mapping import GroupTypeMapping
-from posthog.schema import CachedTeamTaxonomyQueryResponse, HumanMessage, TeamTaxonomyQuery, VisualizationMessage
+from posthog.schema import (
+    CachedTeamTaxonomyQueryResponse,
+    FailureMessage,
+    TeamTaxonomyQuery,
+    VisualizationMessage,
+)
 
 
 class CreateTrendsPlanNode(AssistantNode):
@@ -170,26 +175,33 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:
         """
         Reconstruct the conversation for the agent. On this step we only care about previously asked questions and generated plans. All other messages are filtered out.
         """
-        messages = state.get("messages", [])
-        if len(messages) == 0:
+        human_messages, visualization_messages = filter_trends_conversation(state.get("messages", []))
+
+        if not human_messages:
             return []
 
-        conversation = [
-            HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format(
-                question=messages[0].content if isinstance(messages[0], HumanMessage) else ""
-            )
-        ]
+        conversation = []
 
-        for message in messages[1:]:
-            if isinstance(message, HumanMessage):
-                conversation.append(
-                    HumanMessagePromptTemplate.from_template(
-                        react_follow_up_prompt,
-                        template_format="mustache",
-                    ).format(feedback=message.content)
-                )
-            elif isinstance(message, VisualizationMessage):
-                conversation.append(LangchainAssistantMessage(content=message.plan or ""))
+        for idx, messages in enumerate(itertools.zip_longest(human_messages, visualization_messages)):
+            human_message, viz_message = messages
+
+            if human_message:
+                if idx == 0:
+                    conversation.append(
+                        HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format(
+                            question=human_message.content
+                        )
+                    )
+                else:
+                    conversation.append(
+                        HumanMessagePromptTemplate.from_template(
+                            react_follow_up_prompt,
+                            template_format="mustache",
+                        ).format(feedback=human_message.content)
+                    )
+
+            if viz_message:
+                conversation.append(LangchainAssistantMessage(content=viz_message.plan or ""))
 
         return conversation
 
@@ -240,30 +252,38 @@ class GenerateTrendsNode(AssistantNode):
 
     def run(self, state: AssistantState, config: RunnableConfig):
         generated_plan = state.get("plan", "")
+        intermediate_steps = state.get("intermediate_steps") or []
+        validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None
 
         trends_generation_prompt = ChatPromptTemplate.from_messages(
             [
                 ("system", trends_system_prompt),
             ],
             template_format="mustache",
-        ) + self._reconstruct_conversation(state)
+        ) + self._reconstruct_conversation(state, validation_error_message=validation_error_message)
         merger = merge_message_runs()
 
-        chain = (
-            trends_generation_prompt
-            | merger
-            | self._model
-            # Result from structured output is a parsed dict. Convert to a string since the output parser expects it.
-            | RunnableLambda(lambda x: json.dumps(x))
-            # Validate a string input.
-            | PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel)
-        )
+        chain = trends_generation_prompt | merger | self._model | parse_generated_trends_output
 
         try:
            message: GenerateTrendOutputModel = chain.invoke({}, config)
-        except OutputParserException:
+        except PydanticOutputParserException as e:
+            # Generation step is expensive. After a second unsuccessful attempt, it's better to send a failure message.
+            if len(intermediate_steps) >= 2:
+                return {
+                    "messages": [
+                        FailureMessage(
+                            content="Oops! It looks like I’m having trouble generating this trends insight. Could you please try again?"
+                        )
+                    ],
+                    "intermediate_steps": None,
+                }
+
             return {
-                "messages": [VisualizationMessage(plan=generated_plan, reasoning_steps=["Schema validation failed"])]
+                "intermediate_steps": [
+                    *intermediate_steps,
+                    (AgentAction("handle_incorrect_response", e.llm_output, e.validation_message), None),
+                ],
             }
 
         return {
@@ -273,11 +293,12 @@ def run(self, state: AssistantState, config: RunnableConfig):
                     reasoning_steps=message.reasoning_steps,
                     answer=message.answer,
                 )
-            ]
+            ],
+            "intermediate_steps": None,
         }
 
     def router(self, state: AssistantState):
-        if state.get("tool_argument") is not None:
+        if state.get("intermediate_steps") is not None:
             return AssistantNodeName.GENERATE_TRENDS_TOOLS
         return AssistantNodeName.END
 
@@ -301,7 +322,9 @@ def _group_mapping_prompt(self) -> str:
         )
         return ET.tostring(root, encoding="unicode")
 
-    def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:
+    def _reconstruct_conversation(
+        self, state: AssistantState, validation_error_message: Optional[str] = None
+    ) -> list[BaseMessage]:
         """
         Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history.
         """
@@ -317,22 +340,7 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:
             )
         ]
 
-        stack: list[LangchainHumanMessage] = []
-        human_messages: list[LangchainHumanMessage] = []
-        visualization_messages: list[VisualizationMessage] = []
-
-        for message in messages:
-            if isinstance(message, HumanMessage):
-                stack.append(LangchainHumanMessage(content=message.content))
-            elif isinstance(message, VisualizationMessage) and message.answer:
-                if stack:
-                    human_messages += merge_message_runs(stack)
-                    stack = []
-                visualization_messages.append(message)
-
-        if stack:
-            human_messages += merge_message_runs(stack)
-
+        human_messages, visualization_messages = filter_trends_conversation(messages)
         first_ai_message = True
 
         for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages):
@@ -364,6 +372,13 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:
                 LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "")
             )
 
+        if validation_error_message:
+            conversation.append(
+                HumanMessagePromptTemplate.from_template(trends_failover_prompt, template_format="mustache").format(
+                    validation_error_message=validation_error_message
+                )
+            )
+
         return conversation
 
     @classmethod
@@ -382,4 +397,20 @@ class GenerateTrendsToolsNode(AssistantNode):
     name = AssistantNodeName.GENERATE_TRENDS_TOOLS
 
     def run(self, state: AssistantState, config: RunnableConfig):
-        return state
+        intermediate_steps = state.get("intermediate_steps", [])
+        if not intermediate_steps:
+            return state
+
+        action, _ = intermediate_steps[-1]
+        prompt = (
+            ChatPromptTemplate.from_template(trends_failover_output_prompt, template_format="mustache")
+            .format_messages(output=action.tool_input, exception_message=action.log)[0]
+            .content
+        )
+
+        return {
+            "intermediate_steps": [
+                *intermediate_steps[:-1],
+                (action, prompt),
+            ]
+        }
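
Taken together, these nodes form a bounded retry loop: a failed schema validation is recorded as an AgentAction in intermediate_steps, the router sends the state to GenerateTrendsToolsNode, which rewrites the failure into a correction prompt, and the edge back to GenerateTrendsNode triggers another attempt; after two recorded failures the node returns a FailureMessage instead. A standalone sketch of that control flow, with hypothetical stand-in callables replacing the real LLM chain and prompts (assumptions, not code from this PR):

# Sketch of the failover control flow implemented by the graph above.
# `generate` and `build_correction_prompt` are hypothetical stand-ins.
from typing import Callable, Optional, Union


def generate_with_failover(
    generate: Callable[[Optional[str]], dict],
    build_correction_prompt: Callable[[ValueError], str],
    max_failures: int = 2,  # mirrors the `len(intermediate_steps) >= 2` cutoff
) -> Union[dict, str]:
    intermediate_steps: list[ValueError] = []
    correction: Optional[str] = None
    while True:
        try:
            # Stand-in for GenerateTrendsNode; raises ValueError on invalid output.
            return generate(correction)
        except ValueError as e:
            if len(intermediate_steps) >= max_failures:
                # Stand-in for the FailureMessage path: give up after repeated failures.
                return "Oops! It looks like I'm having trouble generating this trends insight."
            intermediate_steps.append(e)
            # Stand-in for GenerateTrendsToolsNode: turn the failure into a correction prompt.
            correction = build_correction_prompt(e)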
24 changes: 24 additions & 0 deletions ee/hogai/trends/parsers.py
@@ -0,0 +1,24 @@
+import json
+
+from pydantic import ValidationError
+
+from ee.hogai.trends.utils import GenerateTrendOutputModel
+
+
+class PydanticOutputParserException(ValueError):
+    llm_output: str
+    """Serialized LLM output."""
+    validation_message: str
+    """Pydantic validation error message."""
+
+    def __init__(self, llm_output: str, validation_message: str):
+        super().__init__(llm_output)
+        self.llm_output = llm_output
+        self.validation_message = validation_message
+
+
+def parse_generated_trends_output(output: dict) -> GenerateTrendOutputModel:
+    try:
+        return GenerateTrendOutputModel.model_validate(output)
+    except ValidationError as e:
+        raise PydanticOutputParserException(llm_output=json.dumps(output), validation_message=e.json(include_url=False))
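
For reference, the parser's contract in isolation. The payload below is hypothetical and assumed to fail validation against GenerateTrendOutputModel:

# Hypothetical invalid payload illustrating the exception's two fields.
from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output

try:
    parse_generated_trends_output({"reasoning_steps": "should be a list"})
except PydanticOutputParserException as e:
    print(e.llm_output)          # serialized LLM output that failed validation
    print(e.validation_message)  # Pydantic's JSON error report, fed into the failover prompt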