Skip to content

Commit

Permalink
fix: parsing tools
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 committed Oct 30, 2024
1 parent 5b5fbff commit 03af295
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ee/hogai/trends/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def run(self, state: AssistantState, config: RunnableConfig):
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
return {"intermediate_steps": [*intermediate_steps, (action, observation)]}
return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]}

# The plan has been found. Move to the generation.
if input.name == "final_answer":
Expand Down
31 changes: 30 additions & 1 deletion ee/hogai/trends/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from langchain_core.messages import AIMessage as LangchainAIMessage
from langchain_core.runnables import RunnableLambda

from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode
from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
from ee.hogai.trends.parsers import AgentAction
from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person
Expand Down Expand Up @@ -131,6 +131,35 @@ def test_agent_handles_output_with_malformed_json(self):
self.assertIn("action_input", action.tool_input)


@override_settings(IN_UNIT_TESTING=True)
class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest):
def test_node_handles_action_name_validation_error(self):
state = {
"intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")],
"messages": [],
}
node = CreateTrendsPlanToolsNode(self.team)
state_update = node.run(state, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, observation = state_update["intermediate_steps"][0]
self.assertIsNotNone(observation)
self.assertIn("<pydantic_exception>", observation)

def test_node_handles_action_input_validation_error(self):
state = {
"intermediate_steps": [
(AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test")
],
"messages": [],
}
node = CreateTrendsPlanToolsNode(self.team)
state_update = node.run(state, {})
self.assertEqual(len(state_update["intermediate_steps"]), 1)
action, observation = state_update["intermediate_steps"][0]
self.assertIsNotNone(observation)
self.assertIn("<pydantic_exception>", observation)


@override_settings(IN_UNIT_TESTING=True)
class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):
def setUp(self):
Expand Down

0 comments on commit 03af295

Please sign in to comment.