feat(product-assistant): better failover for the ReAct agent (#25903)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
skoob13 and github-actions[bot] authored Oct 30, 2024
1 parent 124d166 commit 1fbc799
Showing 5 changed files with 281 additions and 28 deletions.
76 changes: 49 additions & 27 deletions ee/hogai/trends/nodes.py
@@ -1,12 +1,10 @@
import itertools
import xml.etree.ElementTree as ET
from functools import cached_property
from typing import Optional, Union, cast
from typing import Optional, cast

from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAssistantMessage
from langchain_core.messages import BaseMessage, merge_message_runs
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
@@ -15,10 +13,20 @@
from pydantic import ValidationError

from ee.hogai.hardcoded_definitions import hardcoded_prop_defs
from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output
from ee.hogai.trends.parsers import (
PydanticOutputParserException,
ReActParserException,
ReActParserMissingActionException,
parse_generated_trends_output,
parse_react_agent_output,
)
from ee.hogai.trends.prompts import (
react_definitions_prompt,
react_follow_up_prompt,
react_malformed_json_prompt,
react_missing_action_correction_prompt,
react_missing_action_prompt,
react_pydantic_validation_exception_prompt,
react_scratchpad_prompt,
react_system_prompt,
react_user_prompt,
@@ -80,40 +88,42 @@ def run(self, state: AssistantState, config: RunnableConfig):
)

toolkit = TrendsAgentToolkit(self._team)
output_parser = ReActJsonSingleInputOutputParser()
merger = merge_message_runs()

agent = prompt | merger | self._model | output_parser
agent = prompt | merger | self._model | parse_react_agent_output

try:
result = cast(
Union[AgentAction, AgentFinish],
AgentAction,
agent.invoke(
{
"tools": toolkit.render_text_description(),
"tool_names": ", ".join([t["name"] for t in toolkit.tools]),
"agent_scratchpad": format_log_to_str(
[(action, output) for action, output in intermediate_steps if output is not None]
),
"agent_scratchpad": self._get_agent_scratchpad(intermediate_steps),
},
config,
),
)
except OutputParserException as e:
text = str(e)
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
except ReActParserException as e:
if isinstance(e, ReActParserMissingActionException):
# When the agent doesn't output the "Action:" block, we need to correct the log and append the action block,
# so that it has a higher chance to recover.
corrected_log = str(
ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache")
.format_messages(output=e.llm_output)[0]
.content
)
result = AgentAction(
"handle_incorrect_response",
react_missing_action_prompt,
corrected_log,
)
else:
observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question."
result = AgentAction("handle_incorrect_response", observation, text)

if isinstance(result, AgentFinish):
# Exceptional case
return {
"plan": result.log,
"intermediate_steps": None,
}
result = AgentAction(
"handle_incorrect_response",
react_malformed_json_prompt,
e.llm_output,
)

return {
"intermediate_steps": [*intermediate_steps, (result, None)],
@@ -205,6 +215,14 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:

return conversation

def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]]) -> str:
actions = []
for action, observation in scratchpad:
if observation is None:
continue
actions.append((action, observation))
return format_log_to_str(actions)


class CreateTrendsPlanToolsNode(AssistantNode):
name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS
@@ -217,8 +235,12 @@ def run(self, state: AssistantState, config: RunnableConfig):
try:
input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root
except ValidationError as e:
feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}"
return {"intermediate_steps": [*intermediate_steps, (action, feedback)]}
observation = (
ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache")
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]}

# The plan has been found. Move to the generation.
if input.name == "final_answer":
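Taken together, the hunks above replace the stock ReActJsonSingleInputOutputParser failover with typed parser exceptions that the node converts into corrective AgentAction steps. A minimal sketch of that recovery branch in isolation, assuming the exception classes and prompt strings added elsewhere in this commit (the helper name recover_from_parser_failure is illustrative, not part of the commit):

from langchain_core.agents import AgentAction
from langchain_core.prompts import ChatPromptTemplate

from ee.hogai.trends.parsers import (
    ReActParserException,
    ReActParserMissingActionException,
)
from ee.hogai.trends.prompts import (
    react_malformed_json_prompt,
    react_missing_action_correction_prompt,
    react_missing_action_prompt,
)


def recover_from_parser_failure(e: ReActParserException) -> AgentAction:
    """Turn a parser failure into a corrective step the agent can act on next turn."""
    if isinstance(e, ReActParserMissingActionException):
        # The model skipped the "Action:" block: append it to the log so the
        # next attempt is more likely to include one.
        corrected_log = str(
            ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache")
            .format_messages(output=e.llm_output)[0]
            .content
        )
        return AgentAction("handle_incorrect_response", react_missing_action_prompt, corrected_log)
    # Malformed or incomplete JSON: keep the raw output as the log and ask for valid JSON.
    return AgentAction("handle_incorrect_response", react_malformed_json_prompt, e.llm_output)

As in run() above, the recovered action is then appended to intermediate_steps with a None observation.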
56 changes: 56 additions & 0 deletions ee/hogai/trends/parsers.py
@@ -1,10 +1,66 @@
import json
import re

from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAIMessage
from pydantic import ValidationError

from ee.hogai.trends.utils import GenerateTrendOutputModel


class ReActParserException(ValueError):
llm_output: str

def __init__(self, llm_output: str):
super().__init__(llm_output)
self.llm_output = llm_output


class ReActParserMalformedJsonException(ReActParserException):
pass


class ReActParserMissingActionException(ReActParserException):
"""
The ReAct agent didn't output the "Action:" block.
"""

pass


ACTION_LOG_PREFIX = "Action:"


def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction:
"""
A ReAct agent must output in this format:
Some thoughts...
Action:
```json
{"action": "action_name", "action_input": "action_input"}
```
"""
text = str(message.content)
if ACTION_LOG_PREFIX not in text:
raise ReActParserMissingActionException(text)
found = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL).search(text)
if not found:
# JSON not found.
raise ReActParserMalformedJsonException(text)
try:
action = found.group(1).strip()
response = json.loads(action)
is_complete = "action" in response and "action_input" in response
except Exception:
# JSON is malformed or has a wrong type.
raise ReActParserMalformedJsonException(text)
if not is_complete:
# JSON does not contain an action.
raise ReActParserMalformedJsonException(text)
return AgentAction(response["action"], response.get("action_input", {}), text)


class PydanticOutputParserException(ValueError):
llm_output: str
"""Serialized LLM output."""
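The parser accepts the fenced-JSON ReAct format shown in its docstring and raises a typed exception otherwise, which is what lets the node choose between the two correction prompts. A short usage sketch with illustrative inputs (not taken from this commit):

from langchain_core.messages import AIMessage as LangchainAIMessage

from ee.hogai.trends.parsers import (
    ReActParserMalformedJsonException,
    ReActParserMissingActionException,
    parse_react_agent_output,
)

# Well-formed output: thoughts, an "Action:" marker, and a fenced JSON payload.
action = parse_react_agent_output(
    LangchainAIMessage(content='Some thoughts...\nAction:\n```json\n{"action": "final_answer", "action_input": "plan"}\n```')
)
assert action.tool == "final_answer"
assert action.tool_input == "plan"

# No "Action:" block at all -> missing-action exception, raw output preserved for correction.
try:
    parse_react_agent_output(LangchainAIMessage(content="Some thoughts..."))
except ReActParserMissingActionException as e:
    assert e.llm_output == "Some thoughts..."

# "Action:" present but no parseable JSON -> malformed-JSON exception.
try:
    parse_react_agent_output(LangchainAIMessage(content="Some thoughts...\nAction: abc"))
except ReActParserMalformedJsonException as e:
    assert "abc" in e.llm_output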
23 changes: 23 additions & 0 deletions ee/hogai/trends/prompts.py
@@ -178,6 +178,29 @@
Improve the previously generated plan based on the feedback: {{feedback}}
"""

react_missing_action_prompt = """
Your previous answer didn't output the `Action:` block. You must always follow the format described in the system prompt.
"""

react_missing_action_correction_prompt = """
{{output}}
Action: I didn't output the `Action:` block.
"""

react_malformed_json_prompt = """
Your previous answer had a malformed JSON. You must return a correct JSON response containing the `action` and `action_input` fields.
"""

react_pydantic_validation_exception_prompt = """
The action input you previously provided didn't pass the validation and raised a Pydantic validation exception.
<pydantic_exception>
{{exception}}
</pydantic_exception>
You must fix the exception and try again.
"""

trends_system_prompt = """
You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can.
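The new prompts use mustache-style {{...}} placeholders and are rendered via ChatPromptTemplate with template_format="mustache", as the nodes above do. A small sketch of how the validation prompt is filled in; ExampleToolArgs is a hypothetical stand-in for a tool's argument model, not part of the commit:

from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, ValidationError

from ee.hogai.trends.prompts import react_pydantic_validation_exception_prompt


class ExampleToolArgs(BaseModel):  # hypothetical stand-in for a tool argument model
    entity: str


try:
    ExampleToolArgs.model_validate({})  # missing required field -> ValidationError
except ValidationError as e:
    observation = (
        ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache")
        .format_messages(exception=e.errors(include_url=False))[0]
        .content
    )
    # The rendered observation wraps the error list in <pydantic_exception> tags,
    # matching what the tools node feeds back to the agent.
    print(observation)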
76 changes: 75 additions & 1 deletion ee/hogai/trends/test/test_nodes.py
@@ -3,9 +3,15 @@

from django.test import override_settings
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAIMessage
from langchain_core.runnables import RunnableLambda

from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode
from ee.hogai.trends.nodes import (
CreateTrendsPlanNode,
CreateTrendsPlanToolsNode,
GenerateTrendsNode,
GenerateTrendsToolsNode,
)
from ee.hogai.trends.utils import GenerateTrendOutputModel
from ee.hogai.utils import AssistantNodeName
from posthog.schema import (
@@ -115,6 +121,74 @@ def test_agent_preserves_low_count_events_for_smaller_teams(self):
self.assertIn("distinctevent", node._events_prompt)
self.assertIn("all events", node._events_prompt)

def test_agent_scratchpad(self):
node = CreateTrendsPlanNode(self.team)
scratchpad = [
(AgentAction(tool="test1", tool_input="input1", log="log1"), "test"),
(AgentAction(tool="test2", tool_input="input2", log="log2"), None),
(AgentAction(tool="test3", tool_input="input3", log="log3"), ""),
]
prompt = node._get_agent_scratchpad(scratchpad)
self.assertIn("log1", prompt)
self.assertIn("log3", prompt)

def test_agent_handles_output_without_action_block(self):
with patch(
"ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")),
):
node = CreateTrendsPlanNode(self.team)
state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, obs = state_update["intermediate_steps"][0]
self.assertIsNone(obs)
self.assertIn("I don't want to output an action.", action.log)
self.assertIn("Action:", action.log)
self.assertIn("Action:", action.tool_input)

def test_agent_handles_output_with_malformed_json(self):
with patch(
"ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")),
):
node = CreateTrendsPlanNode(self.team)
state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, obs = state_update["intermediate_steps"][0]
self.assertIsNone(obs)
self.assertIn("Thought.\nAction: abc", action.log)
self.assertIn("action", action.tool_input)
self.assertIn("action_input", action.tool_input)


@override_settings(IN_UNIT_TESTING=True)
class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest):
def test_node_handles_action_name_validation_error(self):
state = {
"intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")],
"messages": [],
}
node = CreateTrendsPlanToolsNode(self.team)
state_update = node.run(state, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, observation = state_update["intermediate_steps"][0]
self.assertIsNotNone(observation)
self.assertIn("<pydantic_exception>", observation)

def test_node_handles_action_input_validation_error(self):
state = {
"intermediate_steps": [
(AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test")
],
"messages": [],
}
node = CreateTrendsPlanToolsNode(self.team)
state_update = node.run(state, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, observation = state_update["intermediate_steps"][0]
self.assertIsNotNone(observation)
self.assertIn("<pydantic_exception>", observation)


@override_settings(IN_UNIT_TESTING=True)
class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):
78 changes: 78 additions & 0 deletions ee/hogai/trends/test/test_parsers.py
@@ -0,0 +1,78 @@
from langchain_core.messages import AIMessage as LangchainAIMessage

from ee.hogai.trends.parsers import (
ReActParserMalformedJsonException,
ReActParserMissingActionException,
parse_react_agent_output,
)
from posthog.test.base import BaseTest


class TestParsers(BaseTest):
def test_parse_react_agent_output(self):
res = parse_react_agent_output(
LangchainAIMessage(
content="""
Some thoughts...
Action:
```json
{"action": "action_name", "action_input": "action_input"}
```
"""
)
)
self.assertEqual(res.tool, "action_name")
self.assertEqual(res.tool_input, "action_input")

res = parse_react_agent_output(
LangchainAIMessage(
content="""
Some thoughts...
Action:
```
{"action": "tool", "action_input": {"key": "value"}}
```
"""
)
)
self.assertEqual(res.tool, "tool")
self.assertEqual(res.tool_input, {"key": "value"})

self.assertRaises(
ReActParserMissingActionException, parse_react_agent_output, LangchainAIMessage(content="Some thoughts...")
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction: abc"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction:"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction: {}"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{}\n```"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{not a json}\n```"),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action":"tool"}\n```'),
)
self.assertRaises(
ReActParserMalformedJsonException,
parse_react_agent_output,
LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action_input":"input"}\n```'),
)
