Skip to content

Commit

Permalink
fix: fallback streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 committed Oct 25, 2024
1 parent 89a1579 commit 680ea43
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def is_message_update(
return len(update) == 2 and update[0] == "messages"


def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]:
"""
Update of the state.
"""
return len(update) == 2 and update[0] == "values"


class Assistant:
_team: Team
_graph: StateGraph
Expand Down Expand Up @@ -80,23 +87,33 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]:
callbacks = [langfuse_handler] if langfuse_handler else []
messages = [message.root for message in conversation.messages]

chunks = AIMessageChunk(content="")
state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}

generator = assistant_graph.stream(
{"messages": messages},
state,
config={"recursion_limit": 24, "callbacks": callbacks},
stream_mode=["messages", "updates"],
stream_mode=["messages", "values", "updates"],
)

chunks = AIMessageChunk(content="")

for update in generator:
if is_value_update(update):
if is_state_update(update):
_, new_state = update
state = new_state

elif is_value_update(update):
_, state_update = update
if (
AssistantNodeName.GENERATE_TRENDS in state_update
and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]
):
message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0])
yield message.model_dump_json()

if AssistantNodeName.GENERATE_TRENDS in state_update:
# Reset chunks when schema validation fails.
chunks = AIMessageChunk(content="")

if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]:
message = cast(
VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]
)
yield message.model_dump_json()

elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance(
Expand All @@ -108,6 +125,3 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
).model_dump_json()
# elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS:
# # Reset tool output parser when encountered a validation error
# chunks = AIMessageChunk(content="")

0 comments on commit 680ea43

Please sign in to comment.