Skip to content

Commit

Permalink
Merge branch 'master' into experiments/fix-updating-release-conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbachhuber committed Dec 20, 2024
2 parents 42f4ee8 + 7d35d2b commit 73b36f3
Show file tree
Hide file tree
Showing 41 changed files with 761 additions and 1,592 deletions.
24 changes: 6 additions & 18 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
from collections.abc import AsyncGenerator, Generator, Iterator
from functools import partial
from collections.abc import Generator, Iterator
from typing import Any, Optional
from uuid import uuid4

from asgiref.sync import sync_to_async
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langfuse.callback import CallbackHandler
Expand All @@ -20,6 +18,7 @@
from ee.hogai.trends.nodes import (
TrendsGeneratorNode,
)
from ee.hogai.utils.asgi import SyncIterableToAsync
from ee.hogai.utils.state import (
GraphMessageUpdateTuple,
GraphTaskStartedUpdateTuple,
Expand Down Expand Up @@ -91,14 +90,8 @@ def stream(self):
return self._astream()
return self._stream()

async def _astream(self) -> AsyncGenerator[str, None]:
generator = self._stream()
while True:
try:
if message := await sync_to_async(partial(next, generator), thread_sensitive=False)():
yield message
except StopIteration:
break
def _astream(self):
return SyncIterableToAsync(self._stream())

def _stream(self) -> Generator[str, None, None]:
state = self._init_or_update_state()
Expand Down Expand Up @@ -155,13 +148,8 @@ def _init_or_update_state(self):
if snapshot.next:
saved_state = validate_state_update(snapshot.values)
self._state = saved_state
if saved_state.intermediate_steps:
intermediate_steps = saved_state.intermediate_steps.copy()
intermediate_steps[-1] = (intermediate_steps[-1][0], self._latest_message.content)
self._graph.update_state(
config,
PartialAssistantState(messages=[self._latest_message], intermediate_steps=intermediate_steps),
)
self._graph.update_state(config, PartialAssistantState(messages=[self._latest_message], resumed=True))

return None
initial_state = self._initial_state
self._state = initial_state
Expand Down
3 changes: 2 additions & 1 deletion ee/hogai/funnels/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_node_runs(self):
new_state,
PartialAssistantState(
messages=[VisualizationMessage(answer=self.schema, plan="Plan", id=new_state.messages[0].id)],
intermediate_steps=None,
intermediate_steps=[],
plan="",
),
)
22 changes: 12 additions & 10 deletions ee/hogai/schema_generator/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def _run_with_prompt(
content=f"Oops! It looks like I’m having trouble generating this {self.INSIGHT_NAME} insight. Could you please try again?"
)
],
intermediate_steps=None,
intermediate_steps=[],
plan="",
)

return PartialAssistantState(
Expand All @@ -106,16 +107,17 @@ def _run_with_prompt(
],
)

final_message = VisualizationMessage(
plan=generated_plan,
answer=message.query,
initiator=start_id,
id=str(uuid4()),
)

return PartialAssistantState(
messages=[
VisualizationMessage(
plan=generated_plan,
answer=message.query,
initiator=start_id,
id=str(uuid4()),
)
],
intermediate_steps=None,
messages=[final_message],
intermediate_steps=[],
plan="",
)

def router(self, state: AssistantState):
Expand Down
10 changes: 6 additions & 4 deletions ee/hogai/schema_generator/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def test_node_runs(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])
self.assertEqual(new_state.plan, "")
self.assertEqual(len(new_state.messages), 1)
self.assertEqual(new_state.messages[0].type, "ai/viz")
self.assertEqual(new_state.messages[0].answer, self.schema)
Expand Down Expand Up @@ -316,7 +317,7 @@ def test_node_leaves_failover(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])

new_state = node.run(
AssistantState(
Expand All @@ -328,7 +329,7 @@ def test_node_leaves_failover(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])

def test_node_leaves_failover_after_second_unsuccessful_attempt(self):
node = DummyGeneratorNode(self.team)
Expand All @@ -348,9 +349,10 @@ def test_node_leaves_failover_after_second_unsuccessful_attempt(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])
self.assertEqual(len(new_state.messages), 1)
self.assertIsInstance(new_state.messages[0], FailureMessage)
self.assertEqual(new_state.plan, "")

def test_agent_reconstructs_conversation_with_failover(self):
action = AgentAction(tool="fix", tool_input="validation error", log="exception")
Expand Down
8 changes: 8 additions & 0 deletions ee/hogai/summarizer/nodes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import json
from time import sleep
from uuid import uuid4

from django.conf import settings
from django.utils import timezone
from django.core.serializers.json import DjangoJSONEncoder
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -76,11 +78,17 @@ def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistant

chain = summarization_prompt | self._model

utc_now = timezone.now().astimezone(datetime.UTC)
project_now = utc_now.astimezone(self._team.timezone_info)

message = chain.invoke(
{
"query_kind": viz_message.answer.kind,
"product_description": self._team.project.product_description,
"results": json.dumps(results_response["results"], cls=DjangoJSONEncoder),
"utc_datetime_display": utc_now.strftime("%Y-%m-%d %H:%M:%S"),
"project_datetime_display": project_now.strftime("%Y-%m-%d %H:%M:%S"),
"project_timezone": self._team.timezone_info.tzname(utc_now),
},
config,
)
Expand Down
24 changes: 18 additions & 6 deletions ee/hogai/summarizer/prompts.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
SUMMARIZER_SYSTEM_PROMPT = """
Act as an expert product manager. Your task is to summarize query results in a a concise way.
Offer actionable feedback if possible. Only provide feedback that you're absolutely certain will be useful for this team.
Act as an expert product manager. Your task is to help the user build a successful product and business.
Also, you're a hedeghog named Max.
Offer actionable feedback if possible. Only provide suggestions you're certain will be useful for this team.
Acknowledge when more information would be needed. When query results are provided, note that the user can already see the chart.
Use Silicon Valley lingo. Be informal but get to the point immediately, without fluff - e.g. don't start with "alright, …".
NEVER use title case, even in headings. Our style is sentence case EVERYWHERE.
You can use Markdown for emphasis. Bullets can improve clarity of action points.
The product being analyzed is described as follows:
{{product_description}}"""

SUMMARIZER_INSTRUCTION_PROMPT = """
Here are the {{query_kind}} results for this question:
Here are results of the {{query_kind}} you created to answer my latest question:
```json
{{results}}
```
Answer my earlier question using the results above. Point out interesting trends or anomalies.
Take into account what you know about my product. If possible, offer actionable feedback, but avoid generic advice.
Limit yourself to a few sentences. The answer needs to be high-impact and relevant for me as a Silicon Valley engineer.
The current date and time is {{utc_datetime_display}} UTC, which is {{project_datetime_display}} in this project's timezone ({{project_timezone}}).
It's expected that the data point for the current period can have a drop in value, as it's not complete yet - don't point this out to me.
Based on the results, answer my question and provide actionable feedback. Avoid generic advice. Take into account what you know about the product.
The answer needs to be high-impact, no more than a few sentences.
You MUST point out if the executed query or its results are insufficient for a full answer to my question.
"""
10 changes: 8 additions & 2 deletions ee/hogai/taxonomy_agent/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,18 @@ def _run_with_toolkit(
)
if input.name == "ask_user_for_help":
# The agent has requested help, so we interrupt the graph.
if not observation:
if not state.resumed:
raise NodeInterrupt(input.arguments)

# Feedback was provided.
last_message = state.messages[-1]
response = ""
if isinstance(last_message, HumanMessage):
response = last_message.content

return PartialAssistantState(
intermediate_steps=[*intermediate_steps[:-1], (action, observation)],
resumed=False,
intermediate_steps=[*intermediate_steps[:-1], (action, response)],
)

output = ""
Expand Down
97 changes: 93 additions & 4 deletions ee/hogai/test/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Optional, cast
from unittest.mock import patch

import pytest
from langchain_core import messages
from langchain_core.agents import AgentAction
from langchain_core.runnables import RunnableConfig, RunnableLambda
Expand All @@ -10,7 +11,7 @@
from pydantic import BaseModel

from ee.models.assistant import Conversation
from posthog.schema import AssistantMessage, HumanMessage, ReasoningMessage
from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, ReasoningMessage
from posthog.test.base import NonAtomicBaseTest

from ..assistant import Assistant
Expand All @@ -24,6 +25,10 @@ def setUp(self):
super().setUp()
self.conversation = Conversation.objects.create(team=self.team, user=self.user)

def _parse_stringified_message(self, message: str) -> tuple[str, Any]:
event_line, data_line, *_ = cast(str, message).split("\n")
return (event_line.removeprefix("event: "), json.loads(data_line.removeprefix("data: ")))

def _run_assistant_graph(
self,
test_graph: Optional[CompiledStateGraph] = None,
Expand All @@ -44,8 +49,7 @@ def _run_assistant_graph(
# Capture and parse output of assistant.stream()
output: list[tuple[str, Any]] = []
for message in assistant.stream():
event_line, data_line, *_ = cast(str, message).split("\n")
output.append((event_line.removeprefix("event: "), json.loads(data_line.removeprefix("data: "))))
output.append(self._parse_stringified_message(message))
return output

def assertConversationEqual(self, output: list[tuple[str, Any]], expected_output: list[tuple[str, Any]]):
Expand Down Expand Up @@ -248,7 +252,7 @@ def test_funnels_interrupt_when_asking_for_help(self):
)
self._test_human_in_the_loop(graph)

def test_intermediate_steps_are_updated_after_feedback(self):
def test_messages_are_updated_after_feedback(self):
with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock:
graph = (
AssistantGraph(self.team)
Expand Down Expand Up @@ -282,6 +286,7 @@ def test_intermediate_steps_are_updated_after_feedback(self):
action, observation = snapshot.values["intermediate_steps"][0]
self.assertEqual(action.tool, "ask_user_for_help")
self.assertIsNone(observation)
self.assertNotIn("resumed", snapshot.values)

self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward")
snapshot: StateSnapshot = graph.get_state(config)
Expand All @@ -294,6 +299,44 @@ def test_intermediate_steps_are_updated_after_feedback(self):
action, observation = snapshot.values["intermediate_steps"][1]
self.assertEqual(action.tool, "ask_user_for_help")
self.assertIsNone(observation)
self.assertFalse(snapshot.values["resumed"])

def test_resuming_uses_saved_state(self):
with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock:
graph = (
AssistantGraph(self.team)
.add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER)
.add_funnel_planner(AssistantNodeName.END)
.compile()
)
config: RunnableConfig = {
"configurable": {
"thread_id": self.conversation.id,
}
}

# Interrupt the graph
message = """
Thought: Let's ask for help.
Action:
```
{
"action": "ask_user_for_help",
"action_input": "Need help with this query"
}
```
"""
mock.return_value = RunnableLambda(lambda _: messages.AIMessage(content=message))

self._run_assistant_graph(graph, conversation=self.conversation)
state: StateSnapshot = graph.get_state(config).values
self.assertIn("start_id", state)
self.assertIsNotNone(state["start_id"])

self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward")
state: StateSnapshot = graph.get_state(config).values
self.assertIn("start_id", state)
self.assertIsNotNone(state["start_id"])

def test_new_conversation_handles_serialized_conversation(self):
graph = (
Expand All @@ -319,3 +362,49 @@ def test_new_conversation_handles_serialized_conversation(self):
is_new_conversation=False,
)
self.assertNotEqual(output[0][0], "conversation")

@pytest.mark.asyncio
async def test_async_stream(self):
graph = (
AssistantGraph(self.team)
.add_node(AssistantNodeName.ROUTER, lambda _: {"messages": [AssistantMessage(content="bar")]})
.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
.add_edge(AssistantNodeName.ROUTER, AssistantNodeName.END)
.compile()
)
assistant = Assistant(self.team, self.conversation, HumanMessage(content="foo"))
assistant._graph = graph

expected_output = [
("message", HumanMessage(content="foo")),
("message", ReasoningMessage(content="Identifying type of analysis")),
("message", AssistantMessage(content="bar")),
]
actual_output = [self._parse_stringified_message(message) async for message in assistant._astream()]
self.assertConversationEqual(actual_output, expected_output)

@pytest.mark.asyncio
async def test_async_stream_handles_exceptions(self):
def node_handler(state):
raise ValueError()

graph = (
AssistantGraph(self.team)
.add_node(AssistantNodeName.ROUTER, node_handler)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
.add_edge(AssistantNodeName.ROUTER, AssistantNodeName.END)
.compile()
)
assistant = Assistant(self.team, self.conversation, HumanMessage(content="foo"))
assistant._graph = graph

expected_output = [
("message", HumanMessage(content="foo")),
("message", ReasoningMessage(content="Identifying type of analysis")),
("message", FailureMessage()),
]
actual_output = []
with self.assertRaises(ValueError):
async for message in assistant._astream():
actual_output.append(self._parse_stringified_message(message))
self.assertConversationEqual(actual_output, expected_output)
3 changes: 2 additions & 1 deletion ee/hogai/trends/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_node_runs(self):
new_state,
PartialAssistantState(
messages=[VisualizationMessage(answer=self.schema, plan="Plan", id=new_state.messages[0].id)],
intermediate_steps=None,
intermediate_steps=[],
plan="",
),
)
Loading

0 comments on commit 73b36f3

Please sign in to comment.