diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 2697d5dc3f3a8..6ca6ed8b50d50 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -228,7 +228,7 @@ def run(self, state: AssistantState, config: RunnableConfig): .format_messages(exception=e.errors(include_url=False))[0] .content ) - return {"intermediate_steps": [*intermediate_steps, (action, observation)]} + return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]} # The plan has been found. Move to the generation. if input.name == "final_answer": diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 184a81ef071e8..1371213cfeab8 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -4,7 +4,7 @@ from langchain_core.messages import AIMessage as LangchainAIMessage from langchain_core.runnables import RunnableLambda -from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode +from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode from ee.hogai.trends.parsers import AgentAction from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @@ -131,6 +131,35 @@ def test_agent_handles_output_with_malformed_json(self): self.assertIn("action_input", action.tool_input) +@override_settings(IN_UNIT_TESTING=True) +class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest): + def test_node_handles_action_name_validation_error(self): + state = { + "intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")], + "messages": [], + } + node = CreateTrendsPlanToolsNode(self.team) + state_update = node.run(state, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, observation = state_update["intermediate_steps"][0] + self.assertIsNotNone(observation) + self.assertIn("", observation) + + def test_node_handles_action_input_validation_error(self): + state = { + "intermediate_steps": [ + (AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test") + ], + "messages": [], + } + node = CreateTrendsPlanToolsNode(self.team) + state_update = node.run(state, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, observation = state_update["intermediate_steps"][0] + self.assertIsNotNone(observation) + self.assertIn("", observation) + + @override_settings(IN_UNIT_TESTING=True) class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest): def setUp(self):