Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(product-assistant): correct exception to stop the async iterator #27101

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
from collections.abc import AsyncGenerator, Generator, Iterator
from functools import partial
from collections.abc import Generator, Iterator
from typing import Any, Optional
from uuid import uuid4

from asgiref.sync import sync_to_async
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langfuse.callback import CallbackHandler
Expand All @@ -20,6 +18,7 @@
from ee.hogai.trends.nodes import (
TrendsGeneratorNode,
)
from ee.hogai.utils.asgi import SyncIterableToAsync
from ee.hogai.utils.state import (
GraphMessageUpdateTuple,
GraphTaskStartedUpdateTuple,
Expand Down Expand Up @@ -91,14 +90,8 @@ def stream(self):
return self._astream()
return self._stream()

async def _astream(self) -> AsyncGenerator[str, None]:
generator = self._stream()
while True:
try:
if message := await sync_to_async(partial(next, generator), thread_sensitive=False)():
yield message
except StopIteration:
break
def _astream(self):
return SyncIterableToAsync(self._stream())

def _stream(self) -> Generator[str, None, None]:
state = self._init_or_update_state()
Expand Down
56 changes: 53 additions & 3 deletions ee/hogai/test/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Optional, cast
from unittest.mock import patch

import pytest
from langchain_core import messages
from langchain_core.agents import AgentAction
from langchain_core.runnables import RunnableConfig, RunnableLambda
Expand All @@ -10,7 +11,7 @@
from pydantic import BaseModel

from ee.models.assistant import Conversation
from posthog.schema import AssistantMessage, HumanMessage, ReasoningMessage
from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, ReasoningMessage
from posthog.test.base import NonAtomicBaseTest

from ..assistant import Assistant
Expand All @@ -24,6 +25,10 @@ def setUp(self):
super().setUp()
self.conversation = Conversation.objects.create(team=self.team, user=self.user)

def _parse_stringified_message(self, message: str) -> tuple[str, Any]:
event_line, data_line, *_ = cast(str, message).split("\n")
return (event_line.removeprefix("event: "), json.loads(data_line.removeprefix("data: ")))

def _run_assistant_graph(
self,
test_graph: Optional[CompiledStateGraph] = None,
Expand All @@ -44,8 +49,7 @@ def _run_assistant_graph(
# Capture and parse output of assistant.stream()
output: list[tuple[str, Any]] = []
for message in assistant.stream():
event_line, data_line, *_ = cast(str, message).split("\n")
output.append((event_line.removeprefix("event: "), json.loads(data_line.removeprefix("data: "))))
output.append(self._parse_stringified_message(message))
return output

def assertConversationEqual(self, output: list[tuple[str, Any]], expected_output: list[tuple[str, Any]]):
Expand Down Expand Up @@ -319,3 +323,49 @@ def test_new_conversation_handles_serialized_conversation(self):
is_new_conversation=False,
)
self.assertNotEqual(output[0][0], "conversation")

@pytest.mark.asyncio
async def test_async_stream(self):
graph = (
AssistantGraph(self.team)
.add_node(AssistantNodeName.ROUTER, lambda _: {"messages": [AssistantMessage(content="bar")]})
.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
.add_edge(AssistantNodeName.ROUTER, AssistantNodeName.END)
.compile()
)
assistant = Assistant(self.team, self.conversation, HumanMessage(content="foo"))
assistant._graph = graph

expected_output = [
("message", HumanMessage(content="foo")),
("message", ReasoningMessage(content="Identifying type of analysis")),
("message", AssistantMessage(content="bar")),
]
actual_output = [self._parse_stringified_message(message) async for message in assistant._astream()]
self.assertConversationEqual(actual_output, expected_output)

@pytest.mark.asyncio
async def test_async_stream_handles_exceptions(self):
def node_handler(state):
raise ValueError()

graph = (
AssistantGraph(self.team)
.add_node(AssistantNodeName.ROUTER, node_handler)
.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
.add_edge(AssistantNodeName.ROUTER, AssistantNodeName.END)
.compile()
)
assistant = Assistant(self.team, self.conversation, HumanMessage(content="foo"))
assistant._graph = graph

expected_output = [
("message", HumanMessage(content="foo")),
("message", ReasoningMessage(content="Identifying type of analysis")),
("message", FailureMessage()),
]
actual_output = []
with self.assertRaises(ValueError):
async for message in assistant._astream():
actual_output.append(self._parse_stringified_message(message))
self.assertConversationEqual(actual_output, expected_output)
34 changes: 34 additions & 0 deletions ee/hogai/utils/asgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
from typing import TypeVar

from asgiref.sync import sync_to_async

T = TypeVar("T")


class SyncIterableToAsync(AsyncIterator[T]):
def __init__(self, iterable: Iterable[T]) -> None:
self._iterable: Iterable[T] = iterable
# async versions of the `next` and `iter` functions
self.next_async: Callable = sync_to_async(self.next, thread_sensitive=False)
self.iter_async: Callable = sync_to_async(iter, thread_sensitive=False)
self.sync_iterator: Iterator[T] | None = None

def __aiter__(self) -> AsyncIterator[T]:
return self

async def __anext__(self) -> T:
if self.sync_iterator is None:
self.sync_iterator = await self.iter_async(self._iterable)
return await self.next_async(self.sync_iterator)

@staticmethod
def next(it: Iterator[T]) -> T:
"""
asyncio expects `StopAsyncIteration` in place of `StopIteration`,
so here's a modified in-built `next` function that can handle this.
"""
try:
return next(it)
except StopIteration:
raise StopAsyncIteration
Loading