Skip to content

Commit

Permalink
feat(product-assistant): enhanced trends generation (#25484)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Michael Matloka <[email protected]>
  • Loading branch information
3 people authored Oct 25, 2024
1 parent 763214d commit 62cfd13
Show file tree
Hide file tree
Showing 36 changed files with 2,171 additions and 573 deletions.
101 changes: 101 additions & 0 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from collections.abc import Generator
from typing import Any, Literal, TypedDict, TypeGuard, Union, cast

from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import StateGraph

from ee import settings
from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
from posthog.models.team.team import Team
from posthog.schema import VisualizationMessage

if settings.LANGFUSE_PUBLIC_KEY:
langfuse_handler = CallbackHandler(
public_key=settings.LANGFUSE_PUBLIC_KEY, secret_key=settings.LANGFUSE_SECRET_KEY, host=settings.LANGFUSE_HOST
)
else:
langfuse_handler = None


def is_value_update(update: list[Any]) -> TypeGuard[tuple[Literal["values"], dict[AssistantNodeName, Any]]]:
"""
Transition between nodes.
"""
return len(update) == 2 and update[0] == "updates"


class LangGraphState(TypedDict):
langgraph_node: AssistantNodeName


def is_message_update(
update: list[Any],
) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]]:
"""
Streaming of messages. Returns a partial state.
"""
return len(update) == 2 and update[0] == "messages"


class Assistant:
_team: Team
_graph: StateGraph

def __init__(self, team: Team):
self._team = team
self._graph = StateGraph(AssistantState)

def _compile_graph(self):
builder = self._graph

create_trends_plan_node = CreateTrendsPlanNode(self._team)
builder.add_node(CreateTrendsPlanNode.name, create_trends_plan_node.run)

create_trends_plan_tools_node = CreateTrendsPlanToolsNode(self._team)
builder.add_node(CreateTrendsPlanToolsNode.name, create_trends_plan_tools_node.run)

generate_trends_node = GenerateTrendsNode(self._team)
builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)

builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
builder.add_conditional_edges(GenerateTrendsNode.name, generate_trends_node.router)

return builder.compile()

def stream(self, conversation: Conversation) -> Generator[str, None, None]:
assistant_graph = self._compile_graph()
callbacks = [langfuse_handler] if langfuse_handler else []
messages = [message.root for message in conversation.messages]

generator = assistant_graph.stream(
{"messages": messages},
config={"recursion_limit": 24, "callbacks": callbacks},
stream_mode=["messages", "updates"],
)

chunks = AIMessageChunk(content="")

for update in generator:
if is_value_update(update):
_, state_update = update
if (
AssistantNodeName.GENERATE_TRENDS in state_update
and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]
):
message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0])
yield message.model_dump_json()
elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance(
langchain_message, AIMessageChunk
):
chunks += langchain_message # type: ignore
parsed_message = GenerateTrendsNode.parse_output(chunks.tool_calls[0]["args"])
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
).model_dump_json()
55 changes: 0 additions & 55 deletions ee/hogai/generate_trends_agent.py

This file was deleted.

6 changes: 3 additions & 3 deletions ee/hogai/hardcoded_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
},
"$identify": {
"label": "Identify",
"description": "A user has been identified with properties",
"description": "Identifies an anonymous user. This event doesn't show how many users you have but rather how many users used an account.",
},
"$create_alias": {
"label": "Alias",
Expand Down Expand Up @@ -915,8 +915,8 @@
"session_properties": {
"$session_duration": {
"label": "Session duration",
"description": "The duration of the session being tracked. Learn more about how PostHog tracks sessions in our documentation.\n\nNote, if the duration is formatted as a single number (not 'HH:MM:SS'), it's in seconds.",
"examples": ["01:04:12"],
"description": "The duration of the session being tracked in seconds.",
"examples": ["30", "146", "2"],
"type": "Numeric",
},
"$start_timestamp": {
Expand Down
77 changes: 0 additions & 77 deletions ee/hogai/system_prompt.py

This file was deleted.

Loading

0 comments on commit 62cfd13

Please sign in to comment.