fix: fallback streaming
skoob13 committed Oct 23, 2024
1 parent 9b643aa commit 7598454
Showing 1 changed file with 60 additions and 23 deletions.
ee/hogai/assistant.py (83 changes: 60 additions & 23 deletions)
@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from typing import cast
+from typing import Any, Literal, TypedDict, TypeGuard, Union, cast
 
 from langchain_core.messages import AIMessageChunk
 from langfuse.callback import CallbackHandler
@@ -22,6 +22,33 @@
 )
 
 
+def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["values"], AssistantState]]:
+    """
+    Update of the state.
+    """
+    return len(update) == 2 and update[0] == "values"
+
+
+def is_value_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], dict[AssistantNodeName, Any]]]:
+    """
+    Transition between nodes.
+    """
+    return len(update) == 2 and update[0] == "updates"
+
+
+class LangGraphState(TypedDict):
+    langgraph_node: AssistantNodeName
+
+
+def is_message_update(
+    update: list[Any],
+) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AssistantMessage, AIMessageChunk, Any], LangGraphState]]]:
+    """
+    Streaming of messages. Returns a partial state.
+    """
+    return len(update) == 2 and update[0] == "messages"
+
+
 class Assistant:
     _team: Team
     _graph: StateGraph
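
A note on the guards above: when stream() is called with a list of stream modes, as the new stream() body in the next hunk does, LangGraph yields (mode, payload) tuples instead of bare payloads, and these TypeGuard helpers let the loop narrow each tuple by its mode tag. A minimal sketch of that dispatch, assuming a compiled graph and an AssistantState-shaped initial state (the names graph and initial_state are hypothetical):

for update in graph.stream(initial_state, stream_mode=["messages", "values", "updates"]):
    if is_state_update(update):      # ("values", state): full AssistantState snapshot after a step
        _, snapshot = update
    elif is_value_update(update):    # ("updates", {node: partial_state}): a node finished running
        _, node_updates = update
    elif is_message_update(update):  # ("messages", (chunk, {"langgraph_node": ...})): token stream
        chunk, metadata = update[1]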
@@ -55,30 +82,40 @@ def _compile_graph(self):

     def stream(self, messages: list[AssistantMessage]) -> Generator[str, None, None]:
         assistant_graph = self._compile_graph()
-        generator = assistant_graph.stream(
-            {"messages": messages},
-            config={"recursion_limit": 24, "callbacks": [langfuse_handler]},
-            stream_mode="messages",
-        )
 
         chunks = AIMessageChunk(content="")
+        state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}
 
-        for message, state in generator:
-            if state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS:
-                if isinstance(message, AssistantMessage):
-                    yield FrontendAssistantMessage(
-                        type=message.type, content=message.content, payload=message.payload
-                    ).model_dump_json()
-                elif isinstance(message, AIMessageChunk):
-                    message = cast(AIMessageChunk, message)
-                    chunks += message  # type: ignore
-                    parsed_message = GenerateTrendsNode.parse_output(chunks.tool_calls[0]["args"])
-                    if parsed_message:
+        for update in assistant_graph.stream(
+            state,
+            config={"recursion_limit": 24, "callbacks": [langfuse_handler]},
+            stream_mode=["messages", "values", "updates"],
+        ):
+            if is_state_update(update):
+                _, new_state = update
+                state = new_state
+
+            elif is_value_update(update):
+                _, payload = update
+
+                # Reset chunks when the trends generation node finishes, so a retry
+                # after a failed schema validation starts fresh.
+                if AssistantNodeName.GENERATE_TRENDS in payload:
+                    chunks = AIMessageChunk(content="")
+
+            elif is_message_update(update):
+                message, langgraph_state = update[1]
+                if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS:
+                    if isinstance(message, AssistantMessage):
                         yield FrontendAssistantMessage(
-                            type="ai",
-                            content=parsed_message.model_dump_json(),
-                            payload=VisualizationMessagePayload(plan=""),
+                            type=message.type, content=message.content, payload=message.payload
                         ).model_dump_json()
-            elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS:
-                # Reset tool output parser when encountered a validation error
-                chunks = AIMessageChunk(content="")
+                    elif isinstance(message, AIMessageChunk):
+                        message = cast(AIMessageChunk, message)
+                        chunks += message  # type: ignore
+                        parsed_message = GenerateTrendsNode.parse_output(chunks.tool_calls[0]["args"])
+                        if parsed_message:
+                            yield FrontendAssistantMessage(
+                                type="ai",
+                                content=parsed_message.model_dump_json(),
+                                payload=VisualizationMessagePayload(plan=""),
+                            ).model_dump_json()
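
The fallback path relies on langchain's chunk concatenation: AIMessageChunk supports +, merging streamed content and tool-call fragments, so the accumulated tool_calls[0]["args"] can be re-parsed after every token until GenerateTrendsNode.parse_output returns a complete message. A rough self-contained sketch of that accumulation pattern (token_stream is a hypothetical iterable of AIMessageChunk pieces; parsing is best-effort, so early iterations may expose only a partial args dict):

from langchain_core.messages import AIMessageChunk

def accumulate_tool_args(token_stream):
    chunks = AIMessageChunk(content="")
    for piece in token_stream:
        chunks += piece  # merges content and tool_call_chunks across pieces
        if chunks.tool_calls:  # parsed best-effort from the accumulated fragments
            yield chunks.tool_calls[0]["args"]  # partial dict that grows as more JSON arrives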
