From 680ea43616ad237ffc408cd18d4c2923aec913ec Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Wed, 23 Oct 2024 19:06:15 +0200 Subject: [PATCH] fix: fallback streaming --- ee/hogai/assistant.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index 539baedf8bd1d..bd54f01ed56a8 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -44,6 +44,13 @@ def is_message_update( return len(update) == 2 and update[0] == "messages" +def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]: + """ + Update of the state. + """ + return len(update) == 2 and update[0] == "values" + + class Assistant: _team: Team _graph: StateGraph @@ -80,23 +87,33 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: callbacks = [langfuse_handler] if langfuse_handler else [] messages = [message.root for message in conversation.messages] + chunks = AIMessageChunk(content="") + state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None} + generator = assistant_graph.stream( - {"messages": messages}, + state, config={"recursion_limit": 24, "callbacks": callbacks}, - stream_mode=["messages", "updates"], + stream_mode=["messages", "values", "updates"], ) - chunks = AIMessageChunk(content="") - for update in generator: - if is_value_update(update): + if is_state_update(update): + _, new_state = update + state = new_state + + elif is_value_update(update): _, state_update = update - if ( - AssistantNodeName.GENERATE_TRENDS in state_update - and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS] - ): - message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]) - yield message.model_dump_json() + + if AssistantNodeName.GENERATE_TRENDS in state_update: + # Reset chunks when schema validation fails. + chunks = AIMessageChunk(content="") + + if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]: + message = cast( + VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0] + ) + yield message.model_dump_json() + elif is_message_update(update): langchain_message, langgraph_state = update[1] if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance( @@ -108,6 +125,3 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: yield VisualizationMessage( reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer ).model_dump_json() - # elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS: - # # Reset tool output parser when encountered a validation error - # chunks = AIMessageChunk(content="")