diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py
index 99d1d06b5b7e7..91949403d14a2 100644
--- a/ee/hogai/assistant.py
+++ b/ee/hogai/assistant.py
@@ -6,7 +6,12 @@
 from langgraph.graph.state import StateGraph
 
 from ee import settings
-from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
+from ee.hogai.trends.nodes import (
+    CreateTrendsPlanNode,
+    CreateTrendsPlanToolsNode,
+    GenerateTrendsNode,
+    GenerateTrendsToolsNode,
+)
 from ee.hogai.utils import AssistantMessage, AssistantNodeName, AssistantState
 from posthog.models.team.team import Team
 from posthog.schema import AssistantMessage as FrontendAssistantMessage
@@ -37,6 +42,10 @@ def _compile_graph(self):
         generate_trends_node = GenerateTrendsNode(self._team)
         builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)
 
+        generate_trends_tools_node = GenerateTrendsToolsNode(self._team)
+        builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run)
+        builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name)
+
         builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
         builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
         builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
@@ -70,3 +79,6 @@ def stream(self, messages: list[AssistantMessage]) -> Generator[str, None, None]
                         content=parsed_message.model_dump_json(),
                         payload=VisualizationMessagePayload(plan=""),
                     ).model_dump_json()
+            elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS:
+                # Reset the tool output parser when a validation error is encountered
+                chunks = AIMessageChunk(content="")
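The `add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name)` line is what makes this a failover loop: `GenerateTrendsNode.router` sends a failed generation to the tools node, and the tools node unconditionally routes back for a retry. Below is a minimal, self-contained sketch of that topology, with toy node names and a toy state rather than the PR's actual classes:

```python
from typing import Optional, TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    attempts: int
    error: Optional[str]


def generate(state: State) -> State:
    # Pretend the first generation fails schema validation and the retry succeeds.
    failed = state["attempts"] == 0
    return {"attempts": state["attempts"] + 1, "error": "validation failed" if failed else None}


def handle_error(state: State) -> State:
    # Stand-in for GenerateTrendsToolsNode: carry the error back into the retry.
    return state


def router(state: State):
    return "tools" if state["error"] else END


builder = StateGraph(State)
builder.add_node("generate", generate)
builder.add_node("tools", handle_error)
builder.add_edge(START, "generate")
builder.add_conditional_edges("generate", router)
builder.add_edge("tools", "generate")  # the retry edge this diff adds

print(builder.compile().invoke({"attempts": 0, "error": None}))
# -> {'attempts': 2, 'error': None}
```

Note that the graph itself imposes no retry cap; in this sketch, as in the diff, termination relies on the model eventually producing valid output or on LangGraph's recursion limit.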
""" @@ -319,6 +322,13 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: if ai_message: conversation.append(AIMessage(content=ai_message.content)) + if validation_error_message: + conversation.append( + HumanMessagePromptTemplate.from_template(trends_failover_prompt, template_format="mustache").format( + exception_message=validation_error_message + ) + ) + return conversation @classmethod @@ -330,6 +340,8 @@ def parse_output(cls, output: dict): def run(self, state: AssistantState, config: RunnableConfig): generated_plan = state.get("plan", "") + intermediate_steps = state.get("intermediate_steps", []) + validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None llm = ChatOpenAI(model="gpt-4o", temperature=0.7, streaming=True).with_structured_output( GenerateTrendTool().schema, @@ -342,7 +354,7 @@ def run(self, state: AssistantState, config: RunnableConfig): ("system", trends_system_prompt), ], template_format="mustache", - ) + self._reconstruct_conversation(state) + ) + self._reconstruct_conversation(state, validation_error_message=validation_error_message) merger = merge_message_runs() chain = ( @@ -357,23 +369,14 @@ def run(self, state: AssistantState, config: RunnableConfig): try: message = chain.invoke({}, config) - except OutputParserException: - # if e.send_to_llm: - # observation = str(e.observation) - # else: - # observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question." - # return {"tool_argument": observation} + except OutputParserException as e: + if e.send_to_llm: + observation = str(e.observation) + else: + observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question." return { - "messages": [ - AssistantMessage( - type="ai", - content=GenerateTrendOutputModel( - reasoning_steps=["Schema validation failed"] - ).model_dump_json(), - payload=VisualizationMessagePayload(plan=generated_plan), - ) - ] + "intermediate_steps": [(AgentAction("handle_incorrect_response", observation, str(e)), None)], } return { @@ -383,7 +386,8 @@ def run(self, state: AssistantState, config: RunnableConfig): content=cast(GenerateTrendOutputModel, message).model_dump_json(), payload=VisualizationMessagePayload(plan=generated_plan), ) - ] + ], + "intermediate_steps": None, } @@ -395,4 +399,8 @@ class GenerateTrendsToolsNode(AssistantNode): name = AssistantNodeName.GENERATE_TRENDS_TOOLS def run(self, state: AssistantState, config: RunnableConfig): - return state + intermediate_steps = state.get("intermediate_steps", []) + if not intermediate_steps: + return state + action, _ = intermediate_steps[-1] + return {"intermediate_steps": (action, action.log)} diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index a6292f38ccb09..c72ca02398be1 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -271,3 +271,13 @@ trends_question_prompt = """ Answer to this question: {{question}} """ + +trends_failover_prompt = """ +The result of your previous generatin raised the Pydantic validation exception: + +``` +{{exception_message}} +``` + +Fix the error and return the correct response. 
+""" diff --git a/ee/hogai/utils.py b/ee/hogai/utils.py index 39580e25db2bb..786be25b6ca02 100644 --- a/ee/hogai/utils.py +++ b/ee/hogai/utils.py @@ -27,7 +27,6 @@ class AssistantState(TypedDict): messages: Annotated[Sequence[AssistantMessage], add_messages] intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]] plan: Optional[str] - tool_argument: Optional[str] class AssistantNodeName(StrEnum):