diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index 35bc23e302f3e..17a1c6341b667 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -148,13 +148,8 @@ def _init_or_update_state(self): if snapshot.next: saved_state = validate_state_update(snapshot.values) self._state = saved_state - if saved_state.intermediate_steps: - intermediate_steps = saved_state.intermediate_steps.copy() - intermediate_steps[-1] = (intermediate_steps[-1][0], self._latest_message.content) - self._graph.update_state( - config, - PartialAssistantState(messages=[self._latest_message], intermediate_steps=intermediate_steps), - ) + self._graph.update_state(config, PartialAssistantState(messages=[self._latest_message], resumed=True)) + return None initial_state = self._initial_state self._state = initial_state diff --git a/ee/hogai/taxonomy_agent/nodes.py b/ee/hogai/taxonomy_agent/nodes.py index bd26a7a93918f..92fe74ae55bcb 100644 --- a/ee/hogai/taxonomy_agent/nodes.py +++ b/ee/hogai/taxonomy_agent/nodes.py @@ -267,12 +267,18 @@ def _run_with_toolkit( ) if input.name == "ask_user_for_help": # The agent has requested help, so we interrupt the graph. - if not observation: + if not state.resumed: raise NodeInterrupt(input.arguments) # Feedback was provided. + last_message = state.messages[-1] + response = "" + if isinstance(last_message, HumanMessage): + response = last_message.content + return PartialAssistantState( - intermediate_steps=[*intermediate_steps[:-1], (action, observation)], + resumed=False, + intermediate_steps=[*intermediate_steps[:-1], (action, response)], ) output = "" diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py index 4f4ad45170b99..48bb9b05d9b7e 100644 --- a/ee/hogai/test/test_assistant.py +++ b/ee/hogai/test/test_assistant.py @@ -252,7 +252,7 @@ def test_funnels_interrupt_when_asking_for_help(self): ) self._test_human_in_the_loop(graph) - def test_intermediate_steps_are_updated_after_feedback(self): + def test_messages_are_updated_after_feedback(self): with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock: graph = ( AssistantGraph(self.team) @@ -286,6 +286,7 @@ def test_intermediate_steps_are_updated_after_feedback(self): action, observation = snapshot.values["intermediate_steps"][0] self.assertEqual(action.tool, "ask_user_for_help") self.assertIsNone(observation) + self.assertNotIn("resumed", snapshot.values) self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward") snapshot: StateSnapshot = graph.get_state(config) @@ -298,6 +299,44 @@ def test_intermediate_steps_are_updated_after_feedback(self): action, observation = snapshot.values["intermediate_steps"][1] self.assertEqual(action.tool, "ask_user_for_help") self.assertIsNone(observation) + self.assertFalse(snapshot.values["resumed"]) + + def test_resuming_uses_saved_state(self): + with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock: + graph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER) + .add_funnel_planner(AssistantNodeName.END) + .compile() + ) + config: RunnableConfig = { + "configurable": { + "thread_id": self.conversation.id, + } + } + + # Interrupt the graph + message = """ + Thought: Let's ask for help. + Action: + ``` + { + "action": "ask_user_for_help", + "action_input": "Need help with this query" + } + ``` + """ + mock.return_value = RunnableLambda(lambda _: messages.AIMessage(content=message)) + + self._run_assistant_graph(graph, conversation=self.conversation) + state: StateSnapshot = graph.get_state(config).values + self.assertIn("start_id", state) + self.assertIsNotNone(state["start_id"]) + + self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward") + state: StateSnapshot = graph.get_state(config).values + self.assertIn("start_id", state) + self.assertIsNotNone(state["start_id"]) def test_new_conversation_handles_serialized_conversation(self): graph = ( diff --git a/ee/hogai/utils/types.py b/ee/hogai/utils/types.py index 2df027b6f85af..917edb3d4987e 100644 --- a/ee/hogai/utils/types.py +++ b/ee/hogai/utils/types.py @@ -27,6 +27,7 @@ class _SharedAssistantState(BaseModel): The ID of the message from which the conversation started. """ plan: Optional[str] = Field(default=None) + resumed: Optional[bool] = Field(default=None) class AssistantState(_SharedAssistantState): @@ -34,7 +35,7 @@ class AssistantState(_SharedAssistantState): class PartialAssistantState(_SharedAssistantState): - messages: Optional[Annotated[Sequence[AssistantMessageUnion], operator.add]] = Field(default=None) + messages: Optional[Sequence[AssistantMessageUnion]] = Field(default=None) class AssistantNodeName(StrEnum):