Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(product-assistant): better failover for the ReAct agent #25903

Merged
merged 8 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 49 additions & 27 deletions ee/hogai/trends/nodes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import itertools
import xml.etree.ElementTree as ET
from functools import cached_property
from typing import Optional, Union, cast
from typing import Optional, cast

from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAssistantMessage
from langchain_core.messages import BaseMessage, merge_message_runs
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
Expand All @@ -15,10 +13,20 @@
from pydantic import ValidationError

from ee.hogai.hardcoded_definitions import hardcoded_prop_defs
from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output
from ee.hogai.trends.parsers import (
PydanticOutputParserException,
ReActParserException,
ReActParserMissingActionException,
parse_generated_trends_output,
parse_react_agent_output,
)
from ee.hogai.trends.prompts import (
react_definitions_prompt,
react_follow_up_prompt,
react_malformed_json_prompt,
react_missing_action_correction_prompt,
react_missing_action_prompt,
react_pydantic_validation_exception_prompt,
react_scratchpad_prompt,
react_system_prompt,
react_user_prompt,
Expand Down Expand Up @@ -80,40 +88,42 @@ def run(self, state: AssistantState, config: RunnableConfig):
)

toolkit = TrendsAgentToolkit(self._team)
output_parser = ReActJsonSingleInputOutputParser()
merger = merge_message_runs()

agent = prompt | merger | self._model | output_parser
agent = prompt | merger | self._model | parse_react_agent_output

try:
result = cast(
Union[AgentAction, AgentFinish],
AgentAction,
agent.invoke(
{
"tools": toolkit.render_text_description(),
"tool_names": ", ".join([t["name"] for t in toolkit.tools]),
"agent_scratchpad": format_log_to_str(
[(action, output) for action, output in intermediate_steps if output is not None]
),
"agent_scratchpad": self._get_agent_scratchpad(intermediate_steps),
},
config,
),
)
except OutputParserException as e:
text = str(e)
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
except ReActParserException as e:
if isinstance(e, ReActParserMissingActionException):
# When the agent doesn't output the "Action:" block, we need to correct the log and append the action block,
# so that it has a higher chance to recover.
corrected_log = str(
ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache")
.format_messages(output=e.llm_output)[0]
.content
)
result = AgentAction(
"handle_incorrect_response",
react_missing_action_prompt,
corrected_log,
)
else:
observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question."
result = AgentAction("handle_incorrect_response", observation, text)

if isinstance(result, AgentFinish):
# Exceptional case
return {
"plan": result.log,
"intermediate_steps": None,
}
result = AgentAction(
"handle_incorrect_response",
react_malformed_json_prompt,
e.llm_output,
)

return {
"intermediate_steps": [*intermediate_steps, (result, None)],
Expand Down Expand Up @@ -205,6 +215,14 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:

return conversation

def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]]) -> str:
actions = []
for action, observation in scratchpad:
if observation is None:
continue
actions.append((action, observation))
return format_log_to_str(actions)


class CreateTrendsPlanToolsNode(AssistantNode):
name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS
Expand All @@ -217,8 +235,12 @@ def run(self, state: AssistantState, config: RunnableConfig):
try:
input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root
except ValidationError as e:
feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}"
return {"intermediate_steps": [*intermediate_steps, (action, feedback)]}
observation = (
ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache")
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]}

# The plan has been found. Move to the generation.
if input.name == "final_answer":
Expand Down
56 changes: 56 additions & 0 deletions ee/hogai/trends/parsers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,66 @@
import json
import re

from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAIMessage
from pydantic import ValidationError

from ee.hogai.trends.utils import GenerateTrendOutputModel


class ReActParserException(ValueError):
llm_output: str

def __init__(self, llm_output: str):
super().__init__(llm_output)
self.llm_output = llm_output


class ReActParserMalformedJsonException(ReActParserException):
pass


class ReActParserMissingActionException(ReActParserException):
"""
The ReAct agent didn't output the "Action:" block.
"""

pass


ACTION_LOG_PREFIX = "Action:"


def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction:
"""
A ReAct agent must output in this format:

Some thoughts...
Action:
```json
{"action": "action_name", "action_input": "action_input"}
```
"""
text = str(message.content)
if ACTION_LOG_PREFIX not in text:
raise ReActParserMissingActionException(text)
found = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL).search(text)
if not found:
# JSON not found.
raise ReActParserMalformedJsonException(text)
try:
action = found.group(1).strip()
response = json.loads(action)
is_complete = "action" in response and "action_input" in response
except Exception:
# JSON is malformed or has a wrong type.
raise ReActParserMalformedJsonException(text)
if not is_complete:
# JSON does not contain an action.
raise ReActParserMalformedJsonException(text)
return AgentAction(response["action"], response.get("action_input", {}), text)


class PydanticOutputParserException(ValueError):
llm_output: str
"""Serialized LLM output."""
Expand Down
23 changes: 23 additions & 0 deletions ee/hogai/trends/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,29 @@
Improve the previously generated plan based on the feedback: {{feedback}}
"""

react_missing_action_prompt = """
Your previous answer didn't output the `Action:` block. You must always follow the format described in the system prompt.
"""

react_missing_action_correction_prompt = """
{{output}}
Action: I didn't output the `Action:` block.
"""

react_malformed_json_prompt = """
Your previous answer had a malformed JSON. You must return a correct JSON response containing the `action` and `action_input` fields.
"""

react_pydantic_validation_exception_prompt = """
The action input you previously provided didn't pass the validation and raised a Pydantic validation exception.

<pydantic_exception>
{{exception}}
</pydantic_exception>

You must fix the exception and try again.
"""

trends_system_prompt = """
You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can.

Expand Down
76 changes: 75 additions & 1 deletion ee/hogai/trends/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@

from django.test import override_settings
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAIMessage
from langchain_core.runnables import RunnableLambda

from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode
from ee.hogai.trends.nodes import (
CreateTrendsPlanNode,
CreateTrendsPlanToolsNode,
GenerateTrendsNode,
GenerateTrendsToolsNode,
)
from ee.hogai.trends.utils import GenerateTrendOutputModel
from ee.hogai.utils import AssistantNodeName
from posthog.schema import (
Expand Down Expand Up @@ -115,6 +121,74 @@ def test_agent_preserves_low_count_events_for_smaller_teams(self):
self.assertIn("distinctevent", node._events_prompt)
self.assertIn("all events", node._events_prompt)

def test_agent_scratchpad(self):
node = CreateTrendsPlanNode(self.team)
scratchpad = [
(AgentAction(tool="test1", tool_input="input1", log="log1"), "test"),
(AgentAction(tool="test2", tool_input="input2", log="log2"), None),
(AgentAction(tool="test3", tool_input="input3", log="log3"), ""),
]
prompt = node._get_agent_scratchpad(scratchpad)
self.assertIn("log1", prompt)
self.assertIn("log3", prompt)

def test_agent_handles_output_without_action_block(self):
with patch(
"ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")),
):
node = CreateTrendsPlanNode(self.team)
state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, obs = state_update["intermediate_steps"][0]
self.assertIsNone(obs)
self.assertIn("I don't want to output an action.", action.log)
self.assertIn("Action:", action.log)
self.assertIn("Action:", action.tool_input)

def test_agent_handles_output_with_malformed_json(self):
with patch(
"ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")),
):
node = CreateTrendsPlanNode(self.team)
state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, obs = state_update["intermediate_steps"][0]
self.assertIsNone(obs)
self.assertIn("Thought.\nAction: abc", action.log)
self.assertIn("action", action.tool_input)
self.assertIn("action_input", action.tool_input)


@override_settings(IN_UNIT_TESTING=True)
class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest):
def test_node_handles_action_name_validation_error(self):
state = {
"intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")],
"messages": [],
}
node = CreateTrendsPlanToolsNode(self.team)
state_update = node.run(state, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, observation = state_update["intermediate_steps"][0]
self.assertIsNotNone(observation)
self.assertIn("<pydantic_exception>", observation)

def test_node_handles_action_input_validation_error(self):
state = {
"intermediate_steps": [
(AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test")
],
"messages": [],
}
node = CreateTrendsPlanToolsNode(self.team)
state_update = node.run(state, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, observation = state_update["intermediate_steps"][0]
self.assertIsNotNone(observation)
self.assertIn("<pydantic_exception>", observation)


@override_settings(IN_UNIT_TESTING=True)
class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):
Expand Down
78 changes: 78 additions & 0 deletions ee/hogai/trends/test/test_parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from langchain_core.messages import AIMessage as LangchainAIMessage

from ee.hogai.trends.parsers import (
ReActParserMalformedJsonException,
ReActParserMissingActionException,
parse_react_agent_output,
)
from posthog.test.base import BaseTest


class TestParsers(BaseTest):
def test_parse_react_agent_output(self):
res = parse_react_agent_output(
LangchainAIMessage(
content="""
Some thoughts...
Action:
```json
{"action": "action_name", "action_input": "action_input"}
```
"""
)
)
self.assertEqual(res.tool, "action_name")
self.assertEqual(res.tool_input, "action_input")

res = parse_react_agent_output(
LangchainAIMessage(
content="""
Some thoughts...
Action:
```
{"action": "tool", "action_input": {"key": "value"}}
```
"""
)
)
self.assertEqual(res.tool, "tool")
self.assertEqual(res.tool_input, {"key": "value"})

self.assertRaises(
ReActParserMissingActionException, parse_react_agent_output, LangchainAIMessage(content="Some thoughts...")
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction: abc"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction:"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction: {}"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{}\n```"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{not a json}\n```"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action":"tool"}\n```'),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action_input":"input"}\n```'),
)
Loading
Loading