feat: fail over
skoob13 committed Oct 23, 2024
1 parent 8d013b8 commit 9b643aa
Showing 4 changed files with 52 additions and 23 deletions.
14 changes: 13 additions & 1 deletion ee/hogai/assistant.py
@@ -6,7 +6,12 @@
 from langgraph.graph.state import StateGraph
 
 from ee import settings
-from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
+from ee.hogai.trends.nodes import (
+    CreateTrendsPlanNode,
+    CreateTrendsPlanToolsNode,
+    GenerateTrendsNode,
+    GenerateTrendsToolsNode,
+)
 from ee.hogai.utils import AssistantMessage, AssistantNodeName, AssistantState
 from posthog.models.team.team import Team
 from posthog.schema import AssistantMessage as FrontendAssistantMessage
@@ -37,6 +42,10 @@ def _compile_graph(self):
         generate_trends_node = GenerateTrendsNode(self._team)
         builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)
 
+        generate_trends_tools_node = GenerateTrendsToolsNode(self._team)
+        builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run)
+        builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name)
+
         builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
         builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
         builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
@@ -70,3 +79,6 @@ def stream(self, messages: list[AssistantMessage]) -> Generator[str, None, None]
                     content=parsed_message.model_dump_json(),
                     payload=VisualizationMessagePayload(plan=""),
                 ).model_dump_json()
+            elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS:
+                # Reset the tool output parser when a validation error is encountered
+                chunks = AIMessageChunk(content="")
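
For orientation, here is a minimal standalone sketch of the retry loop this commit wires into the graph: the generator's router sends the run to a tools node whenever a validation failure has been recorded in intermediate_steps, and a plain edge loops back for another generation attempt. Node names and the state shape below are simplified stand-ins, not the PostHog types.

from typing import Optional, TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    intermediate_steps: Optional[list[tuple[str, Optional[str]]]]


def generate(state: State) -> State:
    # Stand-in for GenerateTrendsNode: on a validation failure it records
    # the error in intermediate_steps instead of terminating the run.
    return state


def generate_tools(state: State) -> State:
    # Stand-in for GenerateTrendsToolsNode: surfaces the recorded error
    # so the next generation attempt can correct it.
    return state


def router(state: State) -> str:
    # Mirrors GenerateTrendsNode.router: retry while a failure is recorded.
    return "generate_tools" if state.get("intermediate_steps") else END


builder = StateGraph(State)
builder.add_node("generate", generate)
builder.add_node("generate_tools", generate_tools)
builder.add_edge(START, "generate")
builder.add_conditional_edges("generate", router)
builder.add_edge("generate_tools", "generate")  # loop back and retry
graph = builder.compile()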
50 changes: 29 additions & 21 deletions ee/hogai/trends/nodes.py
@@ -2,7 +2,7 @@
 import json
 import xml.etree.ElementTree as ET
 from functools import cached_property
-from typing import Union, cast
+from typing import Optional, Union, cast
 
 from langchain.agents.format_scratchpad import format_log_to_str
 from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
@@ -22,6 +22,7 @@
     react_scratchpad_prompt,
     react_system_prompt,
     react_user_prompt,
+    trends_failover_prompt,
     trends_group_mapping_prompt,
     trends_new_plan_prompt,
     trends_plan_prompt,
@@ -254,11 +255,13 @@ def _group_mapping_prompt(self) -> str:
         return ET.tostring(root, encoding="unicode")
 
     def router(self, state: AssistantState):
-        if state.get("tool_argument") is not None:
+        if state.get("intermediate_steps", []):
             return AssistantNodeName.GENERATE_TRENDS_TOOLS
         return AssistantNodeName.END
 
-    def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:
+    def _reconstruct_conversation(
+        self, state: AssistantState, validation_error_message: Optional[str] = None
+    ) -> list[BaseMessage]:
         """
         Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history.
         """
@@ -319,6 +322,13 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:
         if ai_message:
             conversation.append(AIMessage(content=ai_message.content))
 
+        if validation_error_message:
+            conversation.append(
+                HumanMessagePromptTemplate.from_template(trends_failover_prompt, template_format="mustache").format(
+                    exception_message=validation_error_message
+                )
+            )
+
         return conversation
 
     @classmethod
@@ -330,6 +340,8 @@ def parse_output(cls, output: dict):
 
     def run(self, state: AssistantState, config: RunnableConfig):
         generated_plan = state.get("plan", "")
+        intermediate_steps = state.get("intermediate_steps", [])
+        validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None
 
         llm = ChatOpenAI(model="gpt-4o", temperature=0.7, streaming=True).with_structured_output(
             GenerateTrendTool().schema,
@@ -342,7 +354,7 @@ def run(self, state: AssistantState, config: RunnableConfig):
                 ("system", trends_system_prompt),
             ],
             template_format="mustache",
-        ) + self._reconstruct_conversation(state)
+        ) + self._reconstruct_conversation(state, validation_error_message=validation_error_message)
         merger = merge_message_runs()
 
         chain = (
@@ -357,23 +369,14 @@ def run(self, state: AssistantState, config: RunnableConfig):
 
         try:
             message = chain.invoke({}, config)
-        except OutputParserException:
-            # if e.send_to_llm:
-            #     observation = str(e.observation)
-            # else:
-            #     observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question."
-            # return {"tool_argument": observation}
+        except OutputParserException as e:
+            if e.send_to_llm:
+                observation = str(e.observation)
+            else:
+                observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question."
 
             return {
-                "messages": [
-                    AssistantMessage(
-                        type="ai",
-                        content=GenerateTrendOutputModel(
-                            reasoning_steps=["Schema validation failed"]
-                        ).model_dump_json(),
-                        payload=VisualizationMessagePayload(plan=generated_plan),
-                    )
-                ]
+                "intermediate_steps": [(AgentAction("handle_incorrect_response", observation, str(e)), None)],
             }
 
         return {
@@ -383,7 +386,8 @@ def run(self, state: AssistantState, config: RunnableConfig):
                     content=cast(GenerateTrendOutputModel, message).model_dump_json(),
                     payload=VisualizationMessagePayload(plan=generated_plan),
                 )
-            ]
+            ],
+            "intermediate_steps": None,
         }
 
 
@@ -395,4 +399,8 @@ class GenerateTrendsToolsNode(AssistantNode):
     name = AssistantNodeName.GENERATE_TRENDS_TOOLS
 
     def run(self, state: AssistantState, config: RunnableConfig):
-        return state
+        intermediate_steps = state.get("intermediate_steps", [])
+        if not intermediate_steps:
+            return state
+        action, _ = intermediate_steps[-1]
+        return {"intermediate_steps": [(action, action.log)]}
10 changes: 10 additions & 0 deletions ee/hogai/trends/prompts.py
@@ -271,3 +271,13 @@
 trends_question_prompt = """
 Answer to this question: {{question}}
 """
+
+
+trends_failover_prompt = """
+The result of your previous generation raised a Pydantic validation exception:
+```
+{{exception_message}}
+```
+
+Fix the error and return the correct response.
+"""
1 change: 0 additions & 1 deletion ee/hogai/utils.py
@@ -27,7 +27,6 @@ class AssistantState(TypedDict):
     messages: Annotated[Sequence[AssistantMessage], add_messages]
     intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]]
     plan: Optional[str]
-    tool_argument: Optional[str]
 
 
 class AssistantNodeName(StrEnum):