Skip to content

Commit

Permalink
Merge branch 'master' into split-temporal-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Twixes committed Oct 30, 2024
2 parents 39adbdc + d915d46 commit aac01bd
Show file tree
Hide file tree
Showing 38 changed files with 1,182 additions and 189 deletions.
3 changes: 2 additions & 1 deletion docker-compose.base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ services:
environment:
- TEMPORAL_ADDRESS=temporal:7233
- TEMPORAL_CORS_ORIGINS=http://localhost:3000
image: temporalio/ui:2.10.3
- TEMPORAL_CSRF_COOKIE_INSECURE=true
image: temporalio/ui:2.31.2
ports:
- 8081:8080
temporal-django-worker:
Expand Down
59 changes: 45 additions & 14 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import StateGraph
from pydantic import BaseModel

from ee import settings
from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
from ee.hogai.trends.nodes import (
CreateTrendsPlanNode,
CreateTrendsPlanToolsNode,
GenerateTrendsNode,
GenerateTrendsToolsNode,
)
from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
from posthog.models.team.team import Team
from posthog.schema import VisualizationMessage
from posthog.schema import AssistantGenerationStatusEvent, AssistantGenerationStatusType, VisualizationMessage

if settings.LANGFUSE_PUBLIC_KEY:
langfuse_handler = CallbackHandler(
Expand Down Expand Up @@ -39,6 +45,13 @@ def is_message_update(
return len(update) == 2 and update[0] == "messages"


def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]:
"""
Update of the state.
"""
return len(update) == 2 and update[0] == "values"


class Assistant:
_team: Team
_graph: StateGraph
Expand All @@ -59,38 +72,56 @@ def _compile_graph(self):
generate_trends_node = GenerateTrendsNode(self._team)
builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)

generate_trends_tools_node = GenerateTrendsToolsNode(self._team)
builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run)
builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name)

builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
builder.add_conditional_edges(GenerateTrendsNode.name, generate_trends_node.router)

return builder.compile()

def stream(self, conversation: Conversation) -> Generator[str, None, None]:
def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
assistant_graph = self._compile_graph()
callbacks = [langfuse_handler] if langfuse_handler else []
messages = [message.root for message in conversation.messages]

chunks = AIMessageChunk(content="")
state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}

generator = assistant_graph.stream(
{"messages": messages},
state,
config={"recursion_limit": 24, "callbacks": callbacks},
stream_mode=["messages", "updates"],
stream_mode=["messages", "values", "updates"],
)

chunks = AIMessageChunk(content="")

# Send a chunk to establish the connection avoiding the worker's timeout.
yield ""
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)

for update in generator:
if is_value_update(update):
if is_state_update(update):
_, new_state = update
state = new_state

elif is_value_update(update):
_, state_update = update
if (
AssistantNodeName.GENERATE_TRENDS in state_update
and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]
):
message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0])
yield message.model_dump_json()

if AssistantNodeName.GENERATE_TRENDS in state_update:
# Reset chunks when schema validation fails.
chunks = AIMessageChunk(content="")

if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]:
message = cast(
VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]
)
yield message
elif state_update[AssistantNodeName.GENERATE_TRENDS].get("intermediate_steps", []):
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)

elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance(
Expand All @@ -101,4 +132,4 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]:
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
).model_dump_json()
)
Loading

0 comments on commit aac01bd

Please sign in to comment.