Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/PostHog/posthog into feature-management-backend-setup
Browse files Browse the repository at this point in the history
  • Loading branch information
havenbarnes committed Dec 20, 2024
2 parents 5c2ba52 + e5a8530 commit ec86c2f
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
9 changes: 2 additions & 7 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,8 @@ def _init_or_update_state(self):
if snapshot.next:
saved_state = validate_state_update(snapshot.values)
self._state = saved_state
if saved_state.intermediate_steps:
intermediate_steps = saved_state.intermediate_steps.copy()
intermediate_steps[-1] = (intermediate_steps[-1][0], self._latest_message.content)
self._graph.update_state(
config,
PartialAssistantState(messages=[self._latest_message], intermediate_steps=intermediate_steps),
)
self._graph.update_state(config, PartialAssistantState(messages=[self._latest_message], resumed=True))

return None
initial_state = self._initial_state
self._state = initial_state
Expand Down
10 changes: 8 additions & 2 deletions ee/hogai/taxonomy_agent/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,18 @@ def _run_with_toolkit(
)
if input.name == "ask_user_for_help":
# The agent has requested help, so we interrupt the graph.
if not observation:
if not state.resumed:
raise NodeInterrupt(input.arguments)

# Feedback was provided.
last_message = state.messages[-1]
response = ""
if isinstance(last_message, HumanMessage):
response = last_message.content

return PartialAssistantState(
intermediate_steps=[*intermediate_steps[:-1], (action, observation)],
resumed=False,
intermediate_steps=[*intermediate_steps[:-1], (action, response)],
)

output = ""
Expand Down
41 changes: 40 additions & 1 deletion ee/hogai/test/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_funnels_interrupt_when_asking_for_help(self):
)
self._test_human_in_the_loop(graph)

def test_intermediate_steps_are_updated_after_feedback(self):
def test_messages_are_updated_after_feedback(self):
with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock:
graph = (
AssistantGraph(self.team)
Expand Down Expand Up @@ -286,6 +286,7 @@ def test_intermediate_steps_are_updated_after_feedback(self):
action, observation = snapshot.values["intermediate_steps"][0]
self.assertEqual(action.tool, "ask_user_for_help")
self.assertIsNone(observation)
self.assertNotIn("resumed", snapshot.values)

self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward")
snapshot: StateSnapshot = graph.get_state(config)
Expand All @@ -298,6 +299,44 @@ def test_intermediate_steps_are_updated_after_feedback(self):
action, observation = snapshot.values["intermediate_steps"][1]
self.assertEqual(action.tool, "ask_user_for_help")
self.assertIsNone(observation)
self.assertFalse(snapshot.values["resumed"])

def test_resuming_uses_saved_state(self):
    """After an interrupt, resuming the graph must keep `start_id` from the saved state.

    Runs the funnel-planner graph twice on the same conversation thread:
    first to trigger an `ask_user_for_help` interrupt, then with a follow-up
    human message. Both snapshots must carry a non-None `start_id`.
    """
    # Force the planner's LLM to always emit an `ask_user_for_help` tool call,
    # which interrupts the graph (see TaxonomyAgentPlannerNode).
    with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock:
        graph = (
            AssistantGraph(self.team)
            .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER)
            .add_funnel_planner(AssistantNodeName.END)
            .compile()
        )
        # The thread_id ties both runs to the same persisted checkpoint.
        config: RunnableConfig = {
            "configurable": {
                "thread_id": self.conversation.id,
            }
        }

        # Interrupt the graph
        # NOTE: this is a runtime string parsed by the agent — content must stay exact.
        message = """
        Thought: Let's ask for help.
        Action:
        ```
        {
            "action": "ask_user_for_help",
            "action_input": "Need help with this query"
        }
        ```
        """
        mock.return_value = RunnableLambda(lambda _: messages.AIMessage(content=message))

        # First run: graph interrupts; saved state must already hold start_id.
        self._run_assistant_graph(graph, conversation=self.conversation)
        state: StateSnapshot = graph.get_state(config).values
        self.assertIn("start_id", state)
        self.assertIsNotNone(state["start_id"])

        # Second run resumes from the checkpoint; start_id must survive the resume.
        self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward")
        state: StateSnapshot = graph.get_state(config).values
        self.assertIn("start_id", state)
        self.assertIsNotNone(state["start_id"])

def test_new_conversation_handles_serialized_conversation(self):
graph = (
Expand Down
3 changes: 2 additions & 1 deletion ee/hogai/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ class _SharedAssistantState(BaseModel):
The ID of the message from which the conversation started.
"""
plan: Optional[str] = Field(default=None)
resumed: Optional[bool] = Field(default=None)


class AssistantState(_SharedAssistantState):
    """Full graph state: shared fields plus the accumulated message history."""

    # `operator.add` makes LangGraph append node-emitted messages to the
    # existing history rather than replacing it on each state update.
    messages: Annotated[Sequence[AssistantMessageUnion], operator.add]


class PartialAssistantState(_SharedAssistantState):
messages: Optional[Annotated[Sequence[AssistantMessageUnion], operator.add]] = Field(default=None)
messages: Optional[Sequence[AssistantMessageUnion]] = Field(default=None)


class AssistantNodeName(StrEnum):
Expand Down

0 comments on commit ec86c2f

Please sign in to comment.