diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index 91949403d14a2..fa753f645555b 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, Literal, TypedDict, TypeGuard, Union, cast from langchain_core.messages import AIMessageChunk from langfuse.callback import CallbackHandler @@ -22,6 +22,33 @@ ) +def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]: + """ + Update of the state. + """ + return len(update) == 2 and update[0] == "values" + + +def is_value_update(update: list[Any]) -> TypeGuard[tuple[Literal["values"], dict[AssistantNodeName, Any]]]: + """ + Transition between nodes. + """ + return len(update) == 2 and update[0] == "updates" + + +class LangGraphState(TypedDict): + langgraph_node: AssistantNodeName + + +def is_message_update( + update: list[Any], +) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AssistantMessage, AIMessageChunk, Any], LangGraphState]]]: + """ + Streaming of messages. Returns a partial state. + """ + return len(update) == 2 and update[0] == "messages" + + class Assistant: _team: Team _graph: StateGraph @@ -55,30 +82,40 @@ def _compile_graph(self): def stream(self, messages: list[AssistantMessage]) -> Generator[str, None, None]: assistant_graph = self._compile_graph() - generator = assistant_graph.stream( - {"messages": messages}, - config={"recursion_limit": 24, "callbacks": [langfuse_handler]}, - stream_mode="messages", - ) chunks = AIMessageChunk(content="") + state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None} - for message, state in generator: - if state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS: - if isinstance(message, AssistantMessage): - yield FrontendAssistantMessage( - type=message.type, content=message.content, payload=message.payload - ).model_dump_json() - elif isinstance(message, AIMessageChunk): - message = cast(AIMessageChunk, message) - chunks += message # type: ignore - parsed_message = GenerateTrendsNode.parse_output(chunks.tool_calls[0]["args"]) - if parsed_message: + for update in assistant_graph.stream( + state, + config={"recursion_limit": 24, "callbacks": [langfuse_handler]}, + stream_mode=["messages", "values", "updates"], + ): + if is_state_update(update): + _, new_state = update + state = new_state + + elif is_value_update(update): + _, payload = update + + # Reset chunks when schema validation fails. + if AssistantNodeName.GENERATE_TRENDS in payload: + chunks = AIMessageChunk(content="") + + elif is_message_update(update): + message, langraph_state = update[1] + if langraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS: + if isinstance(message, AssistantMessage): yield FrontendAssistantMessage( - type="ai", - content=parsed_message.model_dump_json(), - payload=VisualizationMessagePayload(plan=""), + type=message.type, content=message.content, payload=message.payload ).model_dump_json() - elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS: - # Reset tool output parser when encountered a validation error - chunks = AIMessageChunk(content="") + elif isinstance(message, AIMessageChunk): + message = cast(AIMessageChunk, message) + chunks += message # type: ignore + parsed_message = GenerateTrendsNode.parse_output(chunks.tool_calls[0]["args"]) + if parsed_message: + yield FrontendAssistantMessage( + type="ai", + content=parsed_message.model_dump_json(), + payload=VisualizationMessagePayload(plan=""), + ).model_dump_json()