diff --git a/ee/api/conversation.py b/ee/api/conversation.py new file mode 100644 index 0000000000000..70e314b94039f --- /dev/null +++ b/ee/api/conversation.py @@ -0,0 +1,69 @@ +from typing import cast + +from django.http import StreamingHttpResponse +from pydantic import ValidationError +from rest_framework import serializers +from rest_framework.renderers import BaseRenderer +from rest_framework.request import Request +from rest_framework.viewsets import GenericViewSet + +from ee.hogai.assistant import Assistant +from ee.models.assistant import Conversation +from posthog.api.routing import TeamAndOrgViewSetMixin +from posthog.models.user import User +from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle +from posthog.schema import HumanMessage + + +class MessageSerializer(serializers.Serializer): + content = serializers.CharField(required=True, max_length=1000) + conversation = serializers.UUIDField(required=False) + + def validate(self, data): + try: + message = HumanMessage(content=data["content"]) + data["message"] = message + except ValidationError: + raise serializers.ValidationError("Invalid message content.") + return data + + +class ServerSentEventRenderer(BaseRenderer): + media_type = "text/event-stream" + format = "txt" + + def render(self, data, accepted_media_type=None, renderer_context=None): + return data + + +class ConversationViewSet(TeamAndOrgViewSetMixin, GenericViewSet): + scope_object = "INTERNAL" + serializer_class = MessageSerializer + renderer_classes = [ServerSentEventRenderer] + queryset = Conversation.objects.all() + lookup_url_kwarg = "conversation" + + def safely_get_queryset(self, queryset): + # Only allow access to conversations created by the current user + return queryset.filter(user=self.request.user) + + def get_throttles(self): + return [AIBurstRateThrottle(), AISustainedRateThrottle()] + + def create(self, request: Request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + conversation_id = serializer.validated_data.get("conversation") + if conversation_id: + self.kwargs[self.lookup_url_kwarg] = conversation_id + conversation = self.get_object() + else: + conversation = self.get_queryset().create(user=request.user, team=self.team) + assistant = Assistant( + self.team, + conversation, + serializer.validated_data["message"], + user=cast(User, request.user), + is_new_conversation=not conversation_id, + ) + return StreamingHttpResponse(assistant.stream(), content_type=ServerSentEventRenderer.media_type) diff --git a/ee/api/test/test_conversation.py b/ee/api/test/test_conversation.py new file mode 100644 index 0000000000000..6eb466876dc01 --- /dev/null +++ b/ee/api/test/test_conversation.py @@ -0,0 +1,157 @@ +from unittest.mock import patch + +from rest_framework import status + +from ee.hogai.assistant import Assistant +from ee.models.assistant import Conversation +from posthog.models.team.team import Team +from posthog.models.user import User +from posthog.test.base import APIBaseTest + + +class TestConversation(APIBaseTest): + def setUp(self): + super().setUp() + self.other_team = Team.objects.create(organization=self.organization, name="other team") + self.other_user = User.objects.create_and_join( + organization=self.organization, + email="other@posthog.com", + password="password", + first_name="Other", + ) + + def _get_streaming_content(self, response): + return b"".join(response.streaming_content) + + def test_create_conversation(self): + with 
patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock: + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"content": "test query"}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self._get_streaming_content(response), b"test response") + self.assertEqual(Conversation.objects.count(), 1) + conversation: Conversation = Conversation.objects.first() + self.assertEqual(conversation.user, self.user) + self.assertEqual(conversation.team, self.team) + stream_mock.assert_called_once() + + def test_add_message_to_existing_conversation(self): + with patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock: + conversation = Conversation.objects.create(user=self.user, team=self.team) + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + { + "conversation": str(conversation.id), + "content": "test query", + }, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self._get_streaming_content(response), b"test response") + self.assertEqual(Conversation.objects.count(), 1) + stream_mock.assert_called_once() + + def test_cant_access_other_users_conversation(self): + conversation = Conversation.objects.create(user=self.other_user, team=self.team) + + self.client.force_login(self.user) + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"conversation": conversation.id, "content": "test query"}, + ) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_cant_access_other_teams_conversation(self): + conversation = Conversation.objects.create(user=self.user, team=self.other_team) + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"conversation": conversation.id, "content": "test query"}, + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_invalid_message_format(self): + response = self.client.post("/api/environments/@current/conversations/") + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_rate_limit_burst(self): + # Create multiple requests to trigger burst rate limit + with patch.object(Assistant, "_stream", return_value=["test response"]): + for _ in range(11): # Assuming burst limit is less than this + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"content": "test query"}, + ) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_empty_content(self): + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"content": ""}, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_content_too_long(self): + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"content": "x" * 1001}, # Very long message + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_invalid_conversation_id(self): + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + { + "conversation": "not-a-valid-uuid", + "content": "test query", + }, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_nonexistent_conversation(self): + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + { + "conversation": "12345678-1234-5678-1234-567812345678", + "content": "test query", + }, + ) + 
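+        # A well-formed UUID that matches no conversation should 404 via get_object(),
+        # rather than silently creating a new conversation.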
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_deleted_conversation(self): + # Create and then delete a conversation + conversation = Conversation.objects.create(user=self.user, team=self.team) + conversation_id = conversation.id + conversation.delete() + + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + { + "conversation": str(conversation_id), + "content": "test query", + }, + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_unauthenticated_request(self): + self.client.logout() + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"content": "test query"}, + ) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_streaming_error_handling(self): + def raise_error(): + yield "some content" + raise Exception("Streaming error") + + with patch.object(Assistant, "_stream", side_effect=raise_error): + response = self.client.post( + f"/api/environments/{self.team.id}/conversations/", + {"content": "test query"}, + ) + with self.assertRaises(Exception) as context: + b"".join(response.streaming_content) + self.assertTrue("Streaming error" in str(context.exception)) diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index 77b1c2c050008..3a296ba9ce7d6 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -1,9 +1,12 @@ +import json from collections.abc import AsyncGenerator, Generator, Iterator from functools import partial -from typing import Any, Literal, Optional, TypedDict, TypeGuard, Union +from typing import Any, Optional +from uuid import uuid4 from asgiref.sync import sync_to_async from langchain_core.messages import AIMessageChunk +from langchain_core.runnables.config import RunnableConfig from langfuse.callback import CallbackHandler from langgraph.graph.state import CompiledStateGraph from pydantic import BaseModel @@ -17,7 +20,19 @@ from ee.hogai.trends.nodes import ( TrendsGeneratorNode, ) -from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation +from ee.hogai.utils.state import ( + GraphMessageUpdateTuple, + GraphTaskStartedUpdateTuple, + GraphValueUpdateTuple, + is_message_update, + is_state_update, + is_task_started_update, + is_value_update, + validate_state_update, + validate_value_update, +) +from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState +from ee.models import Conversation from posthog.event_usage import report_user_action from posthog.models import Team, User from posthog.schema import ( @@ -40,42 +55,6 @@ langfuse_handler = None -def is_value_update(update: list[Any]) -> TypeGuard[tuple[Literal["values"], dict[AssistantNodeName, AssistantState]]]: - """ - Transition between nodes. - """ - return len(update) == 2 and update[0] == "updates" - - -class LangGraphState(TypedDict): - langgraph_node: AssistantNodeName - - -def is_message_update( - update: list[Any], -) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]]: - """ - Streaming of messages. Returns a partial state. - """ - return len(update) == 2 and update[0] == "messages" - - -def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]: - """ - Update of the state. - """ - return len(update) == 2 and update[0] == "values" - - -def is_task_started_update( - update: list[Any], -) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]]: - """ - Streaming of messages. 
Returns a partial state. - """ - return len(update) == 2 and update[0] == "debug" and update[1]["type"] == "task" - - VISUALIZATION_NODES: dict[AssistantNodeName, type[SchemaGeneratorNode]] = { AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode, AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode, @@ -87,13 +66,25 @@ class Assistant: _graph: CompiledStateGraph _user: Optional[User] _conversation: Conversation + _latest_message: HumanMessage + _state: Optional[AssistantState] - def __init__(self, team: Team, conversation: Conversation, user: Optional[User] = None): + def __init__( + self, + team: Team, + conversation: Conversation, + new_message: HumanMessage, + user: Optional[User] = None, + is_new_conversation: bool = False, + ): self._team = team self._user = user self._conversation = conversation + self._latest_message = new_message.model_copy(deep=True, update={"id": str(uuid4())}) + self._is_new_conversation = is_new_conversation self._graph = AssistantGraph(team).compile_full_graph() self._chunks = AIMessageChunk(content="") + self._state = None def stream(self): if SERVER_GATEWAY_INTERFACE == "ASGI": @@ -110,15 +101,19 @@ async def _astream(self) -> AsyncGenerator[str, None]: break def _stream(self) -> Generator[str, None, None]: - callbacks = [langfuse_handler] if langfuse_handler else [] + state = self._init_or_update_state() + config = self._get_config() + generator: Iterator[Any] = self._graph.stream( - self._initial_state, - config={"recursion_limit": 24, "callbacks": callbacks}, - stream_mode=["messages", "values", "updates", "debug"], + state, config=config, stream_mode=["messages", "values", "updates", "debug"] ) - # Send a chunk to establish the connection avoiding the worker's timeout. - yield self._serialize_message(AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)) + # Assign the conversation id to the client. + if self._is_new_conversation: + yield self._serialize_conversation() + + # Send the last message with the initialized id. + yield self._serialize_message(self._latest_message) try: last_viz_message = None @@ -127,7 +122,15 @@ def _stream(self) -> Generator[str, None, None]: if isinstance(message, VisualizationMessage): last_viz_message = message yield self._serialize_message(message) - self._report_conversation(last_viz_message) + + # Check if the assistant has requested help. 
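+            # A non-empty `next` in the graph snapshot means execution paused on an
+            # interrupt (human-in-the-loop), so the interrupt value is streamed back
+            # to the client as an assistant message instead of reporting the insight.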
+ state = self._graph.get_state(config) + if state.next: + yield self._serialize_message( + AssistantMessage(content=state.tasks[0].interrupts[0].value, id=str(uuid4())) + ) + else: + self._report_conversation_state(last_viz_message) except: # This is an unhandled error, so we just stop further generation at this point yield self._serialize_message(FailureMessage()) @@ -135,8 +138,34 @@ def _stream(self) -> Generator[str, None, None]: @property def _initial_state(self) -> AssistantState: - messages = [message.root for message in self._conversation.messages] - return {"messages": messages, "intermediate_steps": None, "plan": None} + return AssistantState(messages=[self._latest_message], start_id=self._latest_message.id) + + def _get_config(self) -> RunnableConfig: + callbacks = [langfuse_handler] if langfuse_handler else [] + config: RunnableConfig = { + "recursion_limit": 24, + "callbacks": callbacks, + "configurable": {"thread_id": self._conversation.id}, + } + return config + + def _init_or_update_state(self): + config = self._get_config() + snapshot = self._graph.get_state(config) + if snapshot.next: + saved_state = validate_state_update(snapshot.values) + self._state = saved_state + if saved_state.intermediate_steps: + intermediate_steps = saved_state.intermediate_steps.copy() + intermediate_steps[-1] = (intermediate_steps[-1][0], self._latest_message.content) + self._graph.update_state( + config, + PartialAssistantState(messages=[self._latest_message], intermediate_steps=intermediate_steps), + ) + return None + initial_state = self._initial_state + self._state = initial_state + return initial_state def _node_to_reasoning_message( self, node_name: AssistantNodeName, input: AssistantState @@ -152,7 +181,7 @@ def _node_to_reasoning_message( ): substeps: list[str] = [] if input: - if intermediate_steps := input.get("intermediate_steps"): + if intermediate_steps := input.intermediate_steps: for action, _ in intermediate_steps: match action.tool: case "retrieve_event_properties": @@ -178,42 +207,65 @@ def _node_to_reasoning_message( return None def _process_update(self, update: Any) -> BaseModel | None: - if is_value_update(update): - _, state_update = update + if is_state_update(update): + _, new_state = update + self._state = validate_state_update(new_state) + elif is_value_update(update) and (new_message := self._process_value_update(update)): + return new_message + elif is_message_update(update) and (new_message := self._process_message_update(update)): + return new_message + elif is_task_started_update(update) and (new_message := self._process_task_started_update(update)): + return new_message + return None - if AssistantNodeName.ROUTER in state_update and "messages" in state_update[AssistantNodeName.ROUTER]: - return state_update[AssistantNodeName.ROUTER]["messages"][0] - elif intersected_nodes := state_update.keys() & VISUALIZATION_NODES.keys(): - # Reset chunks when schema validation fails. - self._chunks = AIMessageChunk(content="") + def _process_value_update(self, update: GraphValueUpdateTuple) -> BaseModel | None: + _, maybe_state_update = update + state_update = validate_value_update(maybe_state_update) + + if node_val := state_update.get(AssistantNodeName.ROUTER): + if isinstance(node_val, PartialAssistantState) and node_val.messages: + return node_val.messages[0] + elif intersected_nodes := state_update.keys() & VISUALIZATION_NODES.keys(): + # Reset chunks when schema validation fails. 
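+            # The generator node retries after a parsing failure, so partially
+            # accumulated tool-call chunks from the failed attempt are discarded here.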
+ self._chunks = AIMessageChunk(content="") - node_name = intersected_nodes.pop() - if "messages" in state_update[node_name]: - return state_update[node_name]["messages"][0] - elif state_update[node_name].get("intermediate_steps", []): - return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) - elif AssistantNodeName.SUMMARIZER in state_update: + node_name = intersected_nodes.pop() + node_val = state_update[node_name] + if not isinstance(node_val, PartialAssistantState): + return None + if node_val.messages: + return node_val.messages[0] + elif node_val.intermediate_steps: + return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) + elif node_val := state_update.get(AssistantNodeName.SUMMARIZER): + if isinstance(node_val, PartialAssistantState) and node_val.messages: self._chunks = AIMessageChunk(content="") - return state_update[AssistantNodeName.SUMMARIZER]["messages"][0] - elif is_message_update(update): - langchain_message, langgraph_state = update[1] - if isinstance(langchain_message, AIMessageChunk): - if langgraph_state["langgraph_node"] in VISUALIZATION_NODES.keys(): - self._chunks += langchain_message # type: ignore - parsed_message = VISUALIZATION_NODES[langgraph_state["langgraph_node"]].parse_output( - self._chunks.tool_calls[0]["args"] - ) - if parsed_message: - return VisualizationMessage(answer=parsed_message.query) - elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER: - self._chunks += langchain_message # type: ignore - return AssistantMessage(content=self._chunks.content) - elif is_task_started_update(update): - _, task_update = update - node_name = task_update["payload"]["name"] # type: ignore - node_input = task_update["payload"]["input"] # type: ignore - if reasoning_message := self._node_to_reasoning_message(node_name, node_input): - return reasoning_message + return node_val.messages[0] + + return None + + def _process_message_update(self, update: GraphMessageUpdateTuple) -> BaseModel | None: + langchain_message, langgraph_state = update[1] + if isinstance(langchain_message, AIMessageChunk): + if langgraph_state["langgraph_node"] in VISUALIZATION_NODES.keys(): + self._chunks += langchain_message # type: ignore + parsed_message = VISUALIZATION_NODES[langgraph_state["langgraph_node"]].parse_output( + self._chunks.tool_calls[0]["args"] + ) + if parsed_message: + initiator_id = self._state.start_id if self._state is not None else None + return VisualizationMessage(answer=parsed_message.query, initiator=initiator_id) + elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER: + self._chunks += langchain_message # type: ignore + return AssistantMessage(content=self._chunks.content) + return None + + def _process_task_started_update(self, update: GraphTaskStartedUpdateTuple) -> BaseModel | None: + _, task_update = update + node_name = task_update["payload"]["name"] # type: ignore + node_input = task_update["payload"]["input"] # type: ignore + if reasoning_message := self._node_to_reasoning_message(node_name, node_input): + return reasoning_message return None def _serialize_message(self, message: BaseModel) -> str: @@ -224,9 +276,15 @@ def _serialize_message(self, message: BaseModel) -> str: output += f"event: {AssistantEventType.MESSAGE}\n" return output + f"data: {message.model_dump_json(exclude_none=True)}\n\n" - def _report_conversation(self, message: Optional[VisualizationMessage]): - human_message = self._conversation.messages[-1].root - if self._user and message and 
isinstance(human_message, HumanMessage): + def _serialize_conversation(self) -> str: + output = f"event: {AssistantEventType.CONVERSATION}\n" + json_conversation = json.dumps({"id": str(self._conversation.id)}) + output += f"data: {json_conversation}\n\n" + return output + + def _report_conversation_state(self, message: Optional[VisualizationMessage]): + human_message = self._latest_message + if self._user and message: report_user_action( self._user, "chat with ai", diff --git a/ee/hogai/django_checkpoint/__init__.py b/ee/hogai/django_checkpoint/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ee/hogai/django_checkpoint/checkpointer.py b/ee/hogai/django_checkpoint/checkpointer.py new file mode 100644 index 0000000000000..78817dca9df76 --- /dev/null +++ b/ee/hogai/django_checkpoint/checkpointer.py @@ -0,0 +1,309 @@ +import json +import random +import threading +from collections.abc import Iterable, Iterator, Sequence +from typing import Any, Optional, cast + +from django.db import transaction +from django.db.models import Q +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + PendingWrite, + get_checkpoint_id, +) +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer +from langgraph.checkpoint.serde.types import ChannelProtocol + +from ee.models.assistant import ConversationCheckpoint, ConversationCheckpointBlob, ConversationCheckpointWrite + + +class DjangoCheckpointer(BaseCheckpointSaver[str]): + jsonplus_serde = JsonPlusSerializer() + _lock: threading.Lock + + def __init__(self, *args): + super().__init__(*args) + self._lock = threading.Lock() + + def _load_writes(self, writes: Sequence[ConversationCheckpointWrite]) -> list[PendingWrite]: + return ( + [ + ( + str(checkpoint_write.task_id), + checkpoint_write.channel, + self.serde.loads_typed((checkpoint_write.type, checkpoint_write.blob)), + ) + for checkpoint_write in writes + if checkpoint_write.type is not None and checkpoint_write.blob is not None + ] + if writes + else [] + ) + + def _load_json(self, obj: Any): + return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(obj)) + + def _dump_json(self, obj: Any) -> dict[str, Any]: + serialized_metadata = self.jsonplus_serde.dumps(obj) + # NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing + nulls_removed = serialized_metadata.decode().replace("\\u0000", "") + return json.loads(nulls_removed) + + def _get_checkpoint_qs( + self, + config: Optional[RunnableConfig], + filter: Optional[dict[str, Any]], + before: Optional[RunnableConfig], + ): + query = Q() + + # construct predicate for config filter + if config and "configurable" in config: + thread_id = config["configurable"].get("thread_id") + query &= Q(thread_id=thread_id) + checkpoint_ns = config["configurable"].get("checkpoint_ns") + if checkpoint_ns is not None: + query &= Q(checkpoint_ns=checkpoint_ns) + if checkpoint_id := get_checkpoint_id(config): + query &= Q(id=checkpoint_id) + + # construct predicate for metadata filter + if filter: + query &= Q(metadata__contains=filter) + + # construct predicate for `before` + if before is not None: + query &= Q(id__lt=get_checkpoint_id(before)) + + return ConversationCheckpoint.objects.filter(query).order_by("-id") + + def _get_checkpoint_channel_values( + self, checkpoint: ConversationCheckpoint + ) -> Iterable[ConversationCheckpointBlob]: 
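+        """Return only the blobs referenced by the checkpoint's channel_versions."""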
+ if not checkpoint.checkpoint: + return [] + loaded_checkpoint = self._load_json(checkpoint.checkpoint) + if "channel_versions" not in loaded_checkpoint: + return [] + query = Q() + for channel, version in loaded_checkpoint["channel_versions"].items(): + query |= Q(channel=channel, version=version) + return checkpoint.blobs.filter(query) + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + + This method retrieves a list of checkpoint tuples from the Postgres database based + on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). + + Args: + config (RunnableConfig): The config to use for listing the checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None. + before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. + limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. + + Yields: + Iterator[CheckpointTuple]: An iterator of checkpoint tuples. + """ + qs = self._get_checkpoint_qs(config, filter, before) + if limit: + qs = qs[:limit] + + for checkpoint in qs: + channel_values = self._get_checkpoint_channel_values(checkpoint) + loaded_checkpoint: Checkpoint = self._load_json(checkpoint.checkpoint) + + checkpoint_dict: Checkpoint = { + **loaded_checkpoint, + "pending_sends": [ + self.serde.loads_typed((checkpoint_write.type, checkpoint_write.blob)) + for checkpoint_write in checkpoint.pending_sends + ], + "channel_values": { + checkpoint_blob.channel: self.serde.loads_typed((checkpoint_blob.type, checkpoint_blob.blob)) + for checkpoint_blob in channel_values + if checkpoint_blob.type is not None + and checkpoint_blob.type != "empty" + and checkpoint_blob.blob is not None + }, + } + + yield CheckpointTuple( + { + "configurable": { + "thread_id": checkpoint.thread_id, + "checkpoint_ns": checkpoint.checkpoint_ns, + "checkpoint_id": checkpoint.id, + } + }, + checkpoint_dict, + self._load_json(checkpoint.metadata), + ( + { + "configurable": { + "thread_id": checkpoint.thread_id, + "checkpoint_ns": checkpoint.checkpoint_ns, + "checkpoint_id": checkpoint.parent_checkpoint_id, + } + } + if checkpoint.parent_checkpoint + else None + ), + self._load_writes(checkpoint.pending_writes), + ) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + + This method retrieves a checkpoint tuple from the Postgres database based on the + provided config. If the config contains a "checkpoint_id" key, the checkpoint with + the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint + for the given thread ID is retrieved. + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + return next(self.list(config), None) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database. + + This method saves a checkpoint to the Postgres database. The checkpoint is associated + with the provided config and its parent config (if any). 
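+        Channels present in new_versions but missing from the checkpoint's
+        channel_values are stored as "empty" blobs, which `list` skips when
+        reconstructing channel values.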
+ + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + configurable = config["configurable"] + thread_id: str = configurable["thread_id"] + checkpoint_id = get_checkpoint_id(config) + checkpoint_ns: str | None = configurable.get("checkpoint_ns") or "" + + checkpoint_copy = cast(dict[str, Any], checkpoint.copy()) + channel_values = checkpoint_copy.pop("channel_values", {}) + + next_config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + with self._lock, transaction.atomic(): + updated_checkpoint, _ = ConversationCheckpoint.objects.update_or_create( + id=checkpoint["id"], + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + defaults={ + "parent_checkpoint_id": checkpoint_id, + "checkpoint": self._dump_json({**checkpoint_copy, "pending_sends": []}), + "metadata": self._dump_json(metadata), + }, + ) + + blobs = [] + for channel, version in new_versions.items(): + type, blob = ( + self.serde.dumps_typed(channel_values[channel]) if channel in channel_values else ("empty", None) + ) + blobs.append( + ConversationCheckpointBlob( + checkpoint=updated_checkpoint, + channel=channel, + version=str(version), + type=type, + blob=blob, + ) + ) + + ConversationCheckpointBlob.objects.bulk_create(blobs, ignore_conflicts=True) + return next_config + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint. + + This method saves intermediate writes associated with a checkpoint to the Postgres database. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + """ + configurable = config["configurable"] + thread_id: str = configurable["thread_id"] + checkpoint_id = get_checkpoint_id(config) + checkpoint_ns: str | None = configurable.get("checkpoint_ns") or "" + + with self._lock, transaction.atomic(): + # `put_writes` and `put` are concurrently called without guaranteeing the call order + # so we need to ensure the checkpoint is created before creating writes. + # Thread.lock() will prevent race conditions though to the same checkpoints within a single pod. 
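+            # get_or_create may insert a stub checkpoint row here; a later `put` call
+            # for the same checkpoint id fills in its data via update_or_create.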
+ checkpoint, _ = ConversationCheckpoint.objects.get_or_create( + id=checkpoint_id, thread_id=thread_id, checkpoint_ns=checkpoint_ns + ) + + writes_to_create = [] + for idx, (channel, value) in enumerate(writes): + type, blob = self.serde.dumps_typed(value) + writes_to_create.append( + ConversationCheckpointWrite( + checkpoint=checkpoint, + task_id=task_id, + idx=idx, + channel=channel, + type=type, + blob=blob, + ) + ) + + ConversationCheckpointWrite.objects.bulk_create( + writes_to_create, + update_conflicts=all(w[0] in WRITES_IDX_MAP for w in writes), + unique_fields=["checkpoint", "task_id", "idx"], + update_fields=["channel", "type", "blob"], + ) + + def get_next_version(self, current: Optional[str | int], channel: ChannelProtocol) -> str: + if current is None: + current_v = 0 + elif isinstance(current, int): + current_v = current + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + next_h = random.random() + return f"{next_v:032}.{next_h:016}" diff --git a/ee/hogai/django_checkpoint/test/test_checkpointer.py b/ee/hogai/django_checkpoint/test/test_checkpointer.py new file mode 100644 index 0000000000000..2f8fd7f4a60ed --- /dev/null +++ b/ee/hogai/django_checkpoint/test/test_checkpointer.py @@ -0,0 +1,274 @@ +# type: ignore + +from typing import Any, TypedDict + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + Checkpoint, + CheckpointMetadata, + create_checkpoint, + empty_checkpoint, +) +from langgraph.checkpoint.base.id import uuid6 +from langgraph.errors import NodeInterrupt +from langgraph.graph import END, START +from langgraph.graph.state import CompiledStateGraph, StateGraph + +from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer +from ee.models.assistant import ( + Conversation, + ConversationCheckpoint, + ConversationCheckpointBlob, + ConversationCheckpointWrite, +) +from posthog.test.base import NonAtomicBaseTest + + +class TestDjangoCheckpointer(NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + + def _build_graph(self, checkpointer: DjangoCheckpointer): + class State(TypedDict): + val: int + + graph = StateGraph(State) + + def handle_node1(state: State) -> State: + if state["val"] == 1: + raise NodeInterrupt("test") + return {"val": state["val"] + 1} + + graph.add_node("node1", handle_node1) + graph.add_node("node2", lambda state: state) + + graph.add_edge(START, "node1") + graph.add_edge("node1", "node2") + graph.add_edge("node2", END) + + return graph.compile(checkpointer=checkpointer) + + def test_saver(self): + thread1 = Conversation.objects.create(user=self.user, team=self.team) + thread2 = Conversation.objects.create(user=self.user, team=self.team) + + config_1: RunnableConfig = { + "configurable": { + "thread_id": thread1.id, + "checkpoint_ns": "", + } + } + chkpnt_1: Checkpoint = empty_checkpoint() + + config_2: RunnableConfig = { + "configurable": { + "thread_id": thread2.id, + "checkpoint_ns": "", + } + } + chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1) + + config_3: RunnableConfig = { + "configurable": { + "thread_id": thread2.id, + "checkpoint_id": chkpnt_2["id"], + "checkpoint_ns": "inner", + } + } + chkpnt_3: Checkpoint = empty_checkpoint() + + metadata_1: CheckpointMetadata = { + "source": "input", + "step": 2, + "writes": {}, + "score": 1, + } + metadata_2: CheckpointMetadata = { + "source": "loop", + "step": 1, + "writes": {"foo": "bar"}, + "score": None, + } + metadata_3: CheckpointMetadata = {} + + test_data = { + "configs": [config_1, config_2, 
config_3], + "checkpoints": [chkpnt_1, chkpnt_2, chkpnt_3], + "metadata": [metadata_1, metadata_2, metadata_3], + } + + saver = DjangoCheckpointer() + + configs = test_data["configs"] + checkpoints = test_data["checkpoints"] + metadata = test_data["metadata"] + + saver.put(configs[0], checkpoints[0], metadata[0], {}) + saver.put(configs[1], checkpoints[1], metadata[1], {}) + saver.put(configs[2], checkpoints[2], metadata[2], {}) + + # call method / assertions + query_1 = {"source": "input"} # search by 1 key + query_2 = { + "step": 1, + "writes": {"foo": "bar"}, + } # search by multiple keys + query_3: dict[str, Any] = {} # search by no keys, return all checkpoints + query_4 = {"source": "update", "step": 1} # no match + + search_results_1 = list(saver.list(None, filter=query_1)) + assert len(search_results_1) == 1 + assert search_results_1[0].metadata == metadata[0] + + search_results_2 = list(saver.list(None, filter=query_2)) + assert len(search_results_2) == 1 + assert search_results_2[0].metadata == metadata[1] + + search_results_3 = list(saver.list(None, filter=query_3)) + assert len(search_results_3) == 3 + + search_results_4 = list(saver.list(None, filter=query_4)) + assert len(search_results_4) == 0 + + # search by config (defaults to checkpoints across all namespaces) + search_results_5 = list(saver.list({"configurable": {"thread_id": thread2.id}})) + assert len(search_results_5) == 2 + assert { + search_results_5[0].config["configurable"]["checkpoint_ns"], + search_results_5[1].config["configurable"]["checkpoint_ns"], + } == {"", "inner"} + + def test_channel_versions(self): + thread1 = Conversation.objects.create(user=self.user, team=self.team) + + chkpnt = { + "v": 1, + "ts": "2024-07-31T20:14:19.804150+00:00", + "id": str(uuid6(clock_seq=-2)), + "channel_values": { + "post": "hog", + "node": "node", + }, + "channel_versions": { + "__start__": 2, + "my_key": 3, + "start:node": 3, + "node": 3, + }, + "versions_seen": { + "__input__": {}, + "__start__": {"__start__": 1}, + "node": {"start:node": 2}, + }, + "pending_sends": [], + } + metadata = {"meta": "key"} + + write_config = {"configurable": {"thread_id": thread1.id, "checkpoint_ns": ""}} + read_config = {"configurable": {"thread_id": thread1.id}} + + saver = DjangoCheckpointer() + saver.put(write_config, chkpnt, metadata, {}) + + checkpoint = ConversationCheckpoint.objects.first() + self.assertIsNotNone(checkpoint) + self.assertEqual(checkpoint.thread, thread1) + self.assertEqual(checkpoint.checkpoint_ns, "") + self.assertEqual(str(checkpoint.id), chkpnt["id"]) + self.assertIsNone(checkpoint.parent_checkpoint) + chkpnt.pop("channel_values") + self.assertEqual(checkpoint.checkpoint, chkpnt) + self.assertEqual(checkpoint.metadata, metadata) + + checkpoints = list(saver.list(read_config)) + self.assertEqual(len(checkpoints), 1) + + checkpoint = saver.get(read_config) + self.assertEqual(checkpoint, checkpoints[0].checkpoint) + + def test_put_copies_checkpoint(self): + thread1 = Conversation.objects.create(user=self.user, team=self.team) + chkpnt = { + "v": 1, + "ts": "2024-07-31T20:14:19.804150+00:00", + "id": str(uuid6(clock_seq=-2)), + "channel_values": { + "post": "hog", + "node": "node", + }, + "channel_versions": { + "__start__": 2, + "my_key": 3, + "start:node": 3, + "node": 3, + }, + "versions_seen": { + "__input__": {}, + "__start__": {"__start__": 1}, + "node": {"start:node": 2}, + }, + "pending_sends": [], + } + metadata = {"meta": "key"} + write_config = {"configurable": {"thread_id": thread1.id, "checkpoint_ns": ""}} 
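+        # put() must not mutate the caller's checkpoint dict when it pops channel_values.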
+ saver = DjangoCheckpointer() + saver.put(write_config, chkpnt, metadata, {}) + self.assertIn("channel_values", chkpnt) + + def test_concurrent_puts_and_put_writes(self): + graph: CompiledStateGraph = self._build_graph(DjangoCheckpointer()) + thread = Conversation.objects.create(user=self.user, team=self.team) + config = {"configurable": {"thread_id": str(thread.id)}} + graph.invoke( + {"val": 0}, + config=config, + ) + self.assertEqual(len(ConversationCheckpoint.objects.all()), 4) + self.assertEqual(len(ConversationCheckpointBlob.objects.all()), 10) + self.assertEqual(len(ConversationCheckpointWrite.objects.all()), 6) + + def test_resuming(self): + checkpointer = DjangoCheckpointer() + graph: CompiledStateGraph = self._build_graph(checkpointer) + thread = Conversation.objects.create(user=self.user, team=self.team) + config = {"configurable": {"thread_id": str(thread.id)}} + + graph.invoke( + {"val": 1}, + config=config, + ) + snapshot = graph.get_state(config) + self.assertIsNotNone(snapshot.next) + self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test") + + self.assertEqual(len(ConversationCheckpoint.objects.all()), 2) + self.assertEqual(len(ConversationCheckpointBlob.objects.all()), 4) + self.assertEqual(len(ConversationCheckpointWrite.objects.all()), 3) + self.assertEqual(len(list(checkpointer.list(config))), 2) + + latest_checkpoint = ConversationCheckpoint.objects.last() + latest_write = ConversationCheckpointWrite.objects.filter(checkpoint=latest_checkpoint).first() + actual_checkpoint = checkpointer.get_tuple(config) + self.assertIsNotNone(actual_checkpoint) + self.assertIsNotNone(latest_write) + self.assertEqual(len(latest_checkpoint.writes.all()), 1) + blobs = list(latest_checkpoint.blobs.all()) + self.assertEqual(len(blobs), 3) + self.assertEqual(actual_checkpoint.checkpoint["id"], str(latest_checkpoint.id)) + self.assertEqual(len(actual_checkpoint.pending_writes), 1) + self.assertEqual(actual_checkpoint.pending_writes[0][0], str(latest_write.task_id)) + + graph.update_state(config, {"val": 2}) + # add the value update checkpoint + self.assertEqual(len(ConversationCheckpoint.objects.all()), 3) + self.assertEqual(len(ConversationCheckpointBlob.objects.all()), 6) + self.assertEqual(len(ConversationCheckpointWrite.objects.all()), 5) + self.assertEqual(len(list(checkpointer.list(config))), 3) + + res = graph.invoke(None, config=config) + self.assertEqual(len(ConversationCheckpoint.objects.all()), 5) + self.assertEqual(len(ConversationCheckpointBlob.objects.all()), 12) + self.assertEqual(len(ConversationCheckpointWrite.objects.all()), 9) + self.assertEqual(len(list(checkpointer.list(config))), 5) + self.assertEqual(res, {"val": 3}) + snapshot = graph.get_state(config) + self.assertFalse(snapshot.next) diff --git a/ee/hogai/eval/tests/test_eval_funnel_generator.py b/ee/hogai/eval/tests/test_eval_funnel_generator.py index cd7e93b260ae9..4d7876ca6f73c 100644 --- a/ee/hogai/eval/tests/test_eval_funnel_generator.py +++ b/ee/hogai/eval/tests/test_eval_funnel_generator.py @@ -1,9 +1,11 @@ +from typing import cast + from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph from ee.hogai.eval.utils import EvalBaseTest -from ee.hogai.utils import AssistantNodeName -from posthog.schema import AssistantFunnelsQuery, HumanMessage +from ee.hogai.utils.types import AssistantNodeName, AssistantState +from posthog.schema import AssistantFunnelsQuery, HumanMessage, VisualizationMessage class TestEvalFunnelGenerator(EvalBaseTest): @@ -14,8 +16,11 @@ def 
_call_node(self, query: str, plan: str) -> AssistantFunnelsQuery: .add_funnel_generator(AssistantNodeName.END) .compile() ) - state = graph.invoke({"messages": [HumanMessage(content=query)], "plan": plan}) - return state["messages"][-1].answer + state = graph.invoke( + AssistantState(messages=[HumanMessage(content=query)], plan=plan), + self._get_config(), + ) + return cast(VisualizationMessage, AssistantState.model_validate(state).messages[-1]).answer def test_node_replaces_equals_with_contains(self): query = "what is the conversion rate from a page view to sign up for users with name John?" diff --git a/ee/hogai/eval/tests/test_eval_funnel_planner.py b/ee/hogai/eval/tests/test_eval_funnel_planner.py index 3760961f9bb03..9adbd75e77c6c 100644 --- a/ee/hogai/eval/tests/test_eval_funnel_planner.py +++ b/ee/hogai/eval/tests/test_eval_funnel_planner.py @@ -5,7 +5,7 @@ from ee.hogai.assistant import AssistantGraph from ee.hogai.eval.utils import EvalBaseTest -from ee.hogai.utils import AssistantNodeName +from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage @@ -40,8 +40,11 @@ def _call_node(self, query): .add_funnel_planner(AssistantNodeName.END) .compile() ) - state = graph.invoke({"messages": [HumanMessage(content=query)]}) - return state["plan"] + state = graph.invoke( + AssistantState(messages=[HumanMessage(content=query)]), + self._get_config(), + ) + return AssistantState.model_validate(state).plan or "" def test_basic_funnel(self): query = "what was the conversion from a page view to sign up?" diff --git a/ee/hogai/eval/tests/test_eval_router.py b/ee/hogai/eval/tests/test_eval_router.py index 25a84769dbfc8..c1307e9d40f00 100644 --- a/ee/hogai/eval/tests/test_eval_router.py +++ b/ee/hogai/eval/tests/test_eval_router.py @@ -1,8 +1,10 @@ +from typing import cast + from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph from ee.hogai.eval.utils import EvalBaseTest -from ee.hogai.utils import AssistantNodeName +from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage, RouterMessage @@ -15,8 +17,11 @@ def _call_node(self, query: str | list): .compile() ) messages = [HumanMessage(content=query)] if isinstance(query, str) else query - state = graph.invoke({"messages": messages}) - return state["messages"][-1].content + state = graph.invoke( + AssistantState(messages=messages), + self._get_config(), + ) + return cast(RouterMessage, AssistantState.model_validate(state).messages[-1]).content def test_outputs_basic_trends_insight(self): query = "Show the $pageview trend" diff --git a/ee/hogai/eval/tests/test_eval_trends_generator.py b/ee/hogai/eval/tests/test_eval_trends_generator.py index c5341584ca2f7..496bbf0100b51 100644 --- a/ee/hogai/eval/tests/test_eval_trends_generator.py +++ b/ee/hogai/eval/tests/test_eval_trends_generator.py @@ -1,9 +1,11 @@ +from typing import cast + from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph from ee.hogai.eval.utils import EvalBaseTest -from ee.hogai.utils import AssistantNodeName -from posthog.schema import AssistantTrendsQuery, HumanMessage +from ee.hogai.utils.types import AssistantNodeName, AssistantState +from posthog.schema import AssistantTrendsQuery, HumanMessage, VisualizationMessage class TestEvalTrendsGenerator(EvalBaseTest): @@ -14,8 +16,11 @@ def _call_node(self, query: str, plan: str) -> AssistantTrendsQuery: .add_trends_generator(AssistantNodeName.END) 
.compile() ) - state = graph.invoke({"messages": [HumanMessage(content=query)], "plan": plan}) - return state["messages"][-1].answer + state = graph.invoke( + AssistantState(messages=[HumanMessage(content=query)], plan=plan), + self._get_config(), + ) + return cast(VisualizationMessage, AssistantState.model_validate(state).messages[-1]).answer def test_node_replaces_equals_with_contains(self): query = "what is pageview trend for users with name John?" diff --git a/ee/hogai/eval/tests/test_eval_trends_planner.py b/ee/hogai/eval/tests/test_eval_trends_planner.py index e7ea741d03687..d4fbff456a91c 100644 --- a/ee/hogai/eval/tests/test_eval_trends_planner.py +++ b/ee/hogai/eval/tests/test_eval_trends_planner.py @@ -5,7 +5,7 @@ from ee.hogai.assistant import AssistantGraph from ee.hogai.eval.utils import EvalBaseTest -from ee.hogai.utils import AssistantNodeName +from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage @@ -40,8 +40,11 @@ def _call_node(self, query): .add_trends_planner(AssistantNodeName.END) .compile() ) - state = graph.invoke({"messages": [HumanMessage(content=query)]}) - return state["plan"] + state = graph.invoke( + AssistantState(messages=[HumanMessage(content=query)]), + self._get_config(), + ) + return AssistantState.model_validate(state).plan or "" def test_no_excessive_property_filters(self): query = "Show the $pageview trend" diff --git a/ee/hogai/eval/utils.py b/ee/hogai/eval/utils.py index 1e50a75daefa2..6e03c4cfafa9f 100644 --- a/ee/hogai/eval/utils.py +++ b/ee/hogai/eval/utils.py @@ -3,15 +3,25 @@ import pytest from django.test import override_settings from flaky import flaky +from langchain_core.runnables import RunnableConfig +from ee.models.assistant import Conversation from posthog.demo.matrix.manager import MatrixManager from posthog.tasks.demo_create_data import HedgeboxMatrix -from posthog.test.base import BaseTest +from posthog.test.base import NonAtomicBaseTest @pytest.mark.skipif(os.environ.get("DEEPEVAL") != "YES", reason="Only runs for the assistant evaluation") @flaky(max_runs=3, min_passes=1) -class EvalBaseTest(BaseTest): +class EvalBaseTest(NonAtomicBaseTest): + def _get_config(self) -> RunnableConfig: + conversation = Conversation.objects.create(team=self.team, user=self.user) + return { + "configurable": { + "thread_id": conversation.id, + } + } + @classmethod def setUpTestData(cls): super().setUpTestData() diff --git a/ee/hogai/funnels/nodes.py b/ee/hogai/funnels/nodes.py index a55bc223847f2..6f71305e0b796 100644 --- a/ee/hogai/funnels/nodes.py +++ b/ee/hogai/funnels/nodes.py @@ -6,12 +6,12 @@ from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode from ee.hogai.schema_generator.utils import SchemaGeneratorOutput from ee.hogai.taxonomy_agent.nodes import TaxonomyAgentPlannerNode, TaxonomyAgentPlannerToolsNode -from ee.hogai.utils import AssistantState +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.schema import AssistantFunnelsQuery class FunnelPlannerNode(TaxonomyAgentPlannerNode): - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: toolkit = FunnelsTaxonomyAgentToolkit(self._team) prompt = ChatPromptTemplate.from_messages( [ @@ -23,7 +23,7 @@ def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: class FunnelPlannerToolsNode(TaxonomyAgentPlannerToolsNode): - def run(self, state: 
AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: toolkit = FunnelsTaxonomyAgentToolkit(self._team) return super()._run_with_toolkit(state, toolkit, config=config) @@ -36,7 +36,7 @@ class FunnelGeneratorNode(SchemaGeneratorNode[AssistantFunnelsQuery]): OUTPUT_MODEL = FunnelsSchemaGeneratorOutput OUTPUT_SCHEMA = FUNNEL_SCHEMA - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: prompt = ChatPromptTemplate.from_messages( [ ("system", FUNNEL_SYSTEM_PROMPT), diff --git a/ee/hogai/funnels/prompts.py b/ee/hogai/funnels/prompts.py index b2deec894a070..3808809c173a7 100644 --- a/ee/hogai/funnels/prompts.py +++ b/ee/hogai/funnels/prompts.py @@ -12,6 +12,8 @@ {{react_format}} +{{react_human_in_the_loop}} + Below you will find information on how to correctly discover the taxonomy of the user's data. diff --git a/ee/hogai/funnels/test/test_nodes.py b/ee/hogai/funnels/test/test_nodes.py index 5c65b14110599..4f4e9fca0e5d4 100644 --- a/ee/hogai/funnels/test/test_nodes.py +++ b/ee/hogai/funnels/test/test_nodes.py @@ -4,6 +4,7 @@ from langchain_core.runnables import RunnableLambda from ee.hogai.funnels.nodes import FunnelGeneratorNode, FunnelsSchemaGeneratorOutput +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.schema import ( AssistantFunnelsQuery, HumanMessage, @@ -15,6 +16,7 @@ @override_settings(IN_UNIT_TESTING=True) class TestFunnelsGeneratorNode(ClickhouseTestMixin, APIBaseTest): def setUp(self): + super().setUp() self.schema = AssistantFunnelsQuery(series=[]) def test_node_runs(self): @@ -24,16 +26,13 @@ def test_node_runs(self): lambda _: FunnelsSchemaGeneratorOutput(query=self.schema).model_dump() ) new_state = node.run( - { - "messages": [HumanMessage(content="Text")], - "plan": "Plan", - }, + AssistantState(messages=[HumanMessage(content="Text")], plan="Plan"), {}, ) self.assertEqual( new_state, - { - "messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)], - "intermediate_steps": None, - }, + PartialAssistantState( + messages=[VisualizationMessage(answer=self.schema, plan="Plan", id=new_state.messages[0].id)], + intermediate_steps=None, + ), ) diff --git a/ee/hogai/funnels/toolkit.py b/ee/hogai/funnels/toolkit.py index 8d6407027aac1..ae603519cc331 100644 --- a/ee/hogai/funnels/toolkit.py +++ b/ee/hogai/funnels/toolkit.py @@ -1,5 +1,5 @@ from ee.hogai.taxonomy_agent.toolkit import TaxonomyAgentToolkit, ToolkitTool -from ee.hogai.utils import dereference_schema +from ee.hogai.utils.helpers import dereference_schema from posthog.schema import AssistantFunnelsQuery diff --git a/ee/hogai/graph.py b/ee/hogai/graph.py index 79e5f914097ce..bf961d6bb9aa8 100644 --- a/ee/hogai/graph.py +++ b/ee/hogai/graph.py @@ -1,10 +1,10 @@ from collections.abc import Hashable from typing import Optional, cast -from langfuse.callback import CallbackHandler +from langchain_core.runnables.base import RunnableLike from langgraph.graph.state import StateGraph -from ee import settings +from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer from ee.hogai.funnels.nodes import ( FunnelGeneratorNode, FunnelGeneratorToolsNode, @@ -19,15 +19,10 @@ TrendsPlannerNode, TrendsPlannerToolsNode, ) -from ee.hogai.utils import AssistantNodeName, AssistantState +from ee.hogai.utils.types import AssistantNodeName, AssistantState from 
posthog.models.team.team import Team -if settings.LANGFUSE_PUBLIC_KEY: - langfuse_handler = CallbackHandler( - public_key=settings.LANGFUSE_PUBLIC_KEY, secret_key=settings.LANGFUSE_SECRET_KEY, host=settings.LANGFUSE_HOST - ) -else: - langfuse_handler = None +checkpointer = DjangoCheckpointer() class AssistantGraph: @@ -45,10 +40,14 @@ def add_edge(self, from_node: AssistantNodeName, to_node: AssistantNodeName): self._graph.add_edge(from_node, to_node) return self + def add_node(self, node: AssistantNodeName, action: RunnableLike): + self._graph.add_node(node, action) + return self + def compile(self): if not self._has_start_node: raise ValueError("Start node not added to the graph") - return self._graph.compile() + return self._graph.compile(checkpointer=checkpointer) def add_start(self): return self.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER) diff --git a/ee/hogai/router/nodes.py b/ee/hogai/router/nodes.py index c9151faaabc29..f6aeacdebbe6b 100644 --- a/ee/hogai/router/nodes.py +++ b/ee/hogai/router/nodes.py @@ -1,4 +1,5 @@ from typing import Literal, cast +from uuid import uuid4 from langchain_core.messages import AIMessage as LangchainAIMessage, BaseMessage from langchain_core.prompts import ChatPromptTemplate @@ -11,7 +12,8 @@ ROUTER_SYSTEM_PROMPT, ROUTER_USER_PROMPT, ) -from ee.hogai.utils import AssistantState, AssistantNode +from ee.hogai.utils.nodes import AssistantNode +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.schema import HumanMessage, RouterMessage RouteName = Literal["trends", "funnel"] @@ -22,7 +24,7 @@ class RouterOutput(BaseModel): class RouterNode(AssistantNode): - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: prompt = ChatPromptTemplate.from_messages( [ ("system", ROUTER_SYSTEM_PROMPT), @@ -31,10 +33,10 @@ def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: ) + self._construct_messages(state) chain = prompt | self._model output: RouterOutput = chain.invoke({}, config) - return {"messages": [RouterMessage(content=output.visualization_type)]} + return PartialAssistantState(messages=[RouterMessage(content=output.visualization_type, id=str(uuid4()))]) def router(self, state: AssistantState) -> RouteName: - last_message = state["messages"][-1] + last_message = state.messages[-1] if isinstance(last_message, RouterMessage): return cast(RouteName, last_message.content) raise ValueError("Invalid route.") @@ -47,7 +49,7 @@ def _model(self): def _construct_messages(self, state: AssistantState): history: list[BaseMessage] = [] - for message in state["messages"]: + for message in state.messages: if isinstance(message, HumanMessage): history += ChatPromptTemplate.from_messages( [("user", ROUTER_USER_PROMPT.strip())], template_format="mustache" diff --git a/ee/hogai/router/test/test_nodes.py b/ee/hogai/router/test/test_nodes.py index 06014fb0b9f59..53074a381b804 100644 --- a/ee/hogai/router/test/test_nodes.py +++ b/ee/hogai/router/test/test_nodes.py @@ -2,11 +2,11 @@ from unittest.mock import patch from django.test import override_settings -from langchain_core.messages import AIMessage as LangchainAIMessage -from langchain_core.messages import HumanMessage as LangchainHumanMessage +from langchain_core.messages import AIMessage as LangchainAIMessage, HumanMessage as LangchainHumanMessage from langchain_core.runnables import RunnableLambda from ee.hogai.router.nodes import 
RouterNode, RouterOutput +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.schema import ( HumanMessage, RouterMessage, @@ -19,7 +19,7 @@ class TestRouterNode(ClickhouseTestMixin, APIBaseTest): def test_router(self): node = RouterNode(self.team) - state: Any = {"messages": [RouterMessage(content="trends")]} + state: Any = AssistantState(messages=[RouterMessage(content="trends")]) self.assertEqual(node.router(state), "trends") def test_node_runs(self): @@ -28,28 +28,36 @@ def test_node_runs(self): return_value=RunnableLambda(lambda _: RouterOutput(visualization_type="funnel")), ): node = RouterNode(self.team) - state: Any = {"messages": [HumanMessage(content="generate trends")]} - self.assertEqual(node.run(state, {}), {"messages": [RouterMessage(content="funnel")]}) + state: Any = AssistantState(messages=[HumanMessage(content="generate trends")]) + next_state = node.run(state, {}) + self.assertEqual( + next_state, + PartialAssistantState(messages=[RouterMessage(content="funnel", id=next_state.messages[0].id)]), + ) with patch( "ee.hogai.router.nodes.RouterNode._model", return_value=RunnableLambda(lambda _: RouterOutput(visualization_type="trends")), ): node = RouterNode(self.team) - state: Any = {"messages": [HumanMessage(content="generate trends")]} - self.assertEqual(node.run(state, {}), {"messages": [RouterMessage(content="trends")]}) + state: Any = AssistantState(messages=[HumanMessage(content="generate trends")]) + next_state = node.run(state, {}) + self.assertEqual( + next_state, + PartialAssistantState(messages=[RouterMessage(content="trends", id=next_state.messages[0].id)]), + ) def test_node_reconstructs_conversation(self): node = RouterNode(self.team) - state: Any = {"messages": [HumanMessage(content="generate trends")]} + state: Any = AssistantState(messages=[HumanMessage(content="generate trends")]) self.assertEqual(node._construct_messages(state), [LangchainHumanMessage(content="Question: generate trends")]) - state = { - "messages": [ + state = AssistantState( + messages=[ HumanMessage(content="generate trends"), RouterMessage(content="trends"), VisualizationMessage(), ] - } + ) self.assertEqual( node._construct_messages(state), [LangchainHumanMessage(content="Question: generate trends"), LangchainAIMessage(content="trends")], diff --git a/ee/hogai/schema_generator/nodes.py b/ee/hogai/schema_generator/nodes.py index f2d383d5c1e30..4bed02fd462cc 100644 --- a/ee/hogai/schema_generator/nodes.py +++ b/ee/hogai/schema_generator/nodes.py @@ -1,10 +1,16 @@ -import itertools import xml.etree.ElementTree as ET +from collections.abc import Sequence from functools import cached_property from typing import Generic, Optional, TypeVar +from uuid import uuid4 from langchain_core.agents import AgentAction -from langchain_core.messages import AIMessage as LangchainAssistantMessage, BaseMessage, merge_message_runs +from langchain_core.messages import ( + AIMessage as LangchainAssistantMessage, + BaseMessage, + HumanMessage as LangchainHumanMessage, + merge_message_runs, +) from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI @@ -23,10 +29,14 @@ QUESTION_PROMPT, ) from ee.hogai.schema_generator.utils import SchemaGeneratorOutput -from ee.hogai.utils import AssistantNode, AssistantState, filter_visualization_conversation +from ee.hogai.utils.helpers import find_last_message_of_type, slice_messages_to_conversation_start +from 
ee.hogai.utils.nodes import AssistantNode +from ee.hogai.utils.types import AssistantMessageUnion, AssistantState, PartialAssistantState from posthog.models.group_type_mapping import GroupTypeMapping from posthog.schema import ( + AssistantMessage, FailureMessage, + HumanMessage, VisualizationMessage, ) @@ -63,9 +73,10 @@ def _run_with_prompt( state: AssistantState, prompt: ChatPromptTemplate, config: Optional[RunnableConfig] = None, - ) -> AssistantState: - generated_plan = state.get("plan", "") - intermediate_steps = state.get("intermediate_steps") or [] + ) -> PartialAssistantState: + start_id = state.start_id + generated_plan = state.plan or "" + intermediate_steps = state.intermediate_steps or [] validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None generation_prompt = prompt + self._construct_messages(state, validation_error_message=validation_error_message) @@ -79,35 +90,36 @@ def _run_with_prompt( except PydanticOutputParserException as e: # Generation step is expensive. After a second unsuccessful attempt, it's better to send a failure message. if len(intermediate_steps) >= 2: - return { - "messages": [ + return PartialAssistantState( + messages=[ FailureMessage( content=f"Oops! It looks like I’m having trouble generating this {self.INSIGHT_NAME} insight. Could you please try again?" ) ], - "intermediate_steps": None, - } + intermediate_steps=None, + ) - return { - "intermediate_steps": [ + return PartialAssistantState( + intermediate_steps=[ *intermediate_steps, (AgentAction("handle_incorrect_response", e.llm_output, e.validation_message), None), ], - } + ) - return { - "messages": [ + return PartialAssistantState( + messages=[ VisualizationMessage( plan=generated_plan, answer=message.query, - done=True, + initiator=start_id, + id=str(uuid4()), ) ], - "intermediate_steps": None, - } + intermediate_steps=None, + ) def router(self, state: AssistantState): - if state.get("intermediate_steps") is not None: + if state.intermediate_steps: return "tools" return "next" @@ -123,15 +135,25 @@ def _group_mapping_prompt(self) -> str: ) return ET.tostring(root, encoding="unicode") + def _get_human_viz_message_mapping(self, messages: Sequence[AssistantMessageUnion]) -> dict[str, int]: + mapping: dict[str, int] = {} + for idx, msg in enumerate(messages): + if isinstance(msg, VisualizationMessage) and msg.initiator is not None: + mapping[msg.initiator] = idx + return mapping + def _construct_messages( self, state: AssistantState, validation_error_message: Optional[str] = None ) -> list[BaseMessage]: """ Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history. """ - messages = state.get("messages", []) - generated_plan = state.get("plan", "") + messages = state.messages + generated_plan = state.plan + start_id = state.start_id + if start_id is not None: + messages = slice_messages_to_conversation_start(messages, start_id) if len(messages) == 0: return [] @@ -141,43 +163,61 @@ def _construct_messages( ) ] - human_messages, visualization_messages = filter_visualization_conversation(messages) - first_ai_message = True + msg_mapping = self._get_human_viz_message_mapping(messages) + initiator_message = messages[-1] + last_viz_message = find_last_message_of_type(messages, VisualizationMessage) + + for message in messages: + # The initial human message and the new plan are added to the end of the conversation. 
+ if message == initiator_message: + continue + if isinstance(message, HumanMessage): + if message.id and (viz_message_idx := msg_mapping.get(message.id)): + # Plans go first. + viz_message = messages[viz_message_idx] + if isinstance(viz_message, VisualizationMessage): + conversation.append( + HumanMessagePromptTemplate.from_template(PLAN_PROMPT, template_format="mustache").format( + plan=viz_message.plan or "" + ) + ) - for idx, (human_message, ai_message) in enumerate( - itertools.zip_longest(human_messages, visualization_messages) - ): - # Plans go first - if ai_message: - conversation.append( - HumanMessagePromptTemplate.from_template( - PLAN_PROMPT if first_ai_message else NEW_PLAN_PROMPT, - template_format="mustache", - ).format(plan=ai_message.plan or "") - ) - first_ai_message = False - elif generated_plan: - conversation.append( - HumanMessagePromptTemplate.from_template( - PLAN_PROMPT if first_ai_message else NEW_PLAN_PROMPT, - template_format="mustache", - ).format(plan=generated_plan) + # Augment with the prompt previous initiator messages. + conversation.append( + HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format( + question=message.content + ) + ) + # Otherwise, just append the human message. + else: + conversation.append(LangchainHumanMessage(content=message.content)) + # Summary, human-in-the-loop messages. + elif isinstance(message, AssistantMessage): + conversation.append(LangchainAssistantMessage(content=message.content)) + + # Include only last generated schema because it doesn't need more context. + if last_viz_message: + conversation.append( + LangchainAssistantMessage( + content=last_viz_message.answer.model_dump_json() if last_viz_message.answer else "" ) - - # Then questions - if human_message: + ) + # Add the initiator message and the generated plan to the end, so instructions are clear. + if isinstance(initiator_message, HumanMessage): + if generated_plan: + plan_prompt = PLAN_PROMPT if messages[0] == initiator_message else NEW_PLAN_PROMPT conversation.append( - HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format( - question=human_message.content + HumanMessagePromptTemplate.from_template(plan_prompt, template_format="mustache").format( + plan=generated_plan or "" ) ) - - # Then schemas, but include only last generated schema because it doesn't need more context. - if ai_message and idx + 1 == len(visualization_messages): - conversation.append( - LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "") + conversation.append( + HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format( + question=initiator_message.content ) + ) + # Retries must be added to the end of the conversation. if validation_error_message: conversation.append( HumanMessagePromptTemplate.from_template(FAILOVER_PROMPT, template_format="mustache").format( @@ -193,10 +233,10 @@ class SchemaGeneratorToolsNode(AssistantNode): Used for failover from generation errors. 
""" - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: - intermediate_steps = state.get("intermediate_steps", []) + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + intermediate_steps = state.intermediate_steps or [] if not intermediate_steps: - return state + return PartialAssistantState() action, _ = intermediate_steps[-1] prompt = ( @@ -205,9 +245,9 @@ def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: .content ) - return { - "intermediate_steps": [ + return PartialAssistantState( + intermediate_steps=[ *intermediate_steps[:-1], (action, str(prompt)), ] - } + ) diff --git a/ee/hogai/schema_generator/test/test_nodes.py b/ee/hogai/schema_generator/test/test_nodes.py index 795045af50b56..b44154b93b927 100644 --- a/ee/hogai/schema_generator/test/test_nodes.py +++ b/ee/hogai/schema_generator/test/test_nodes.py @@ -4,10 +4,11 @@ from django.test import override_settings from langchain_core.agents import AgentAction from langchain_core.prompts import ChatPromptTemplate -from langchain_core.runnables import RunnableLambda +from langchain_core.runnables import RunnableConfig, RunnableLambda from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode from ee.hogai.schema_generator.utils import SchemaGeneratorOutput +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.schema import ( AssistantMessage, AssistantTrendsQuery, @@ -16,7 +17,7 @@ RouterMessage, VisualizationMessage, ) -from posthog.test.base import APIBaseTest, ClickhouseTestMixin +from posthog.test.base import BaseTest TestSchema = SchemaGeneratorOutput[AssistantTrendsQuery] @@ -26,7 +27,7 @@ class DummyGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]): OUTPUT_MODEL = SchemaGeneratorOutput[AssistantTrendsQuery] OUTPUT_SCHEMA = {} - def run(self, state, config): + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: prompt = ChatPromptTemplate.from_messages( [ ("system", "system_prompt"), @@ -36,8 +37,9 @@ def run(self, state, config): @override_settings(IN_UNIT_TESTING=True) -class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest): +class TestSchemaGeneratorNode(BaseTest): def setUp(self): + super().setUp() self.schema = AssistantTrendsQuery(series=[]) def test_node_runs(self): @@ -45,23 +47,23 @@ def test_node_runs(self): with patch.object(DummyGeneratorNode, "_model") as generator_model_mock: generator_model_mock.return_value = RunnableLambda(lambda _: TestSchema(query=self.schema).model_dump()) new_state = node.run( - { - "messages": [HumanMessage(content="Text")], - "plan": "Plan", - }, + AssistantState( + messages=[HumanMessage(content="Text", id="0")], + plan="Plan", + start_id="0", + ), {}, ) - self.assertEqual( - new_state, - { - "messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)], - "intermediate_steps": None, - }, - ) + self.assertIsNone(new_state.intermediate_steps) + self.assertEqual(len(new_state.messages), 1) + self.assertEqual(new_state.messages[0].type, "ai/viz") + self.assertEqual(new_state.messages[0].answer, self.schema) - def test_agent_reconstructs_conversation(self): + def test_agent_reconstructs_conversation_and_does_not_add_an_empty_plan(self): node = DummyGeneratorNode(self.team) - history = node._construct_messages({"messages": [HumanMessage(content="Text")]}) + history = node._construct_messages( + AssistantState(messages=[HumanMessage(content="Text", id="0")], start_id="0") 
+ ) self.assertEqual(len(history), 2) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) @@ -69,7 +71,11 @@ def test_agent_reconstructs_conversation(self): self.assertIn("Answer to this question:", history[1].content) self.assertNotIn("{{question}}", history[1].content) - history = node._construct_messages({"messages": [HumanMessage(content="Text")], "plan": "randomplan"}) + def test_agent_reconstructs_conversation_adds_plan(self): + node = DummyGeneratorNode(self.team) + history = node._construct_messages( + AssistantState(messages=[HumanMessage(content="Text", id="0")], plan="randomplan", start_id="0") + ) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) @@ -82,16 +88,18 @@ def test_agent_reconstructs_conversation(self): self.assertNotIn("{{question}}", history[2].content) self.assertIn("Text", history[2].content) + def test_agent_reconstructs_conversation_can_handle_follow_ups(self): node = DummyGeneratorNode(self.team) history = node._construct_messages( - { - "messages": [ - HumanMessage(content="Text"), - VisualizationMessage(answer=self.schema, plan="randomplan"), - HumanMessage(content="Follow Up"), + AssistantState( + messages=[ + HumanMessage(content="Text", id="0"), + VisualizationMessage(answer=self.schema, plan="randomplan", id="1", initiator="0"), + HumanMessage(content="Follow Up", id="2"), ], - "plan": "newrandomplan", - } + plan="newrandomplan", + start_id="2", + ) ) self.assertEqual(len(history), 6) @@ -116,13 +124,41 @@ def test_agent_reconstructs_conversation(self): self.assertNotIn("{{question}}", history[5].content) self.assertIn("Follow Up", history[5].content) - def test_agent_reconstructs_conversation_and_merges_messages(self): + def test_agent_reconstructs_conversation_and_does_not_merge_messages(self): + node = DummyGeneratorNode(self.team) + history = node._construct_messages( + AssistantState( + messages=[HumanMessage(content="Te", id="0"), HumanMessage(content="xt", id="1")], + plan="randomplan", + start_id="1", + ) + ) + self.assertEqual(len(history), 4) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertIn("Te", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertNotIn("{{plan}}", history[2].content) + self.assertIn("randomplan", history[2].content) + self.assertEqual(history[3].type, "human") + self.assertIn("Answer to this question:", history[3].content) + self.assertNotIn("{{question}}", history[3].content) + self.assertEqual(history[3].type, "human") + self.assertIn("xt", history[3].content) + + def test_filters_out_human_in_the_loop_after_initiator(self): node = DummyGeneratorNode(self.team) history = node._construct_messages( - { - "messages": [HumanMessage(content="Te"), HumanMessage(content="xt")], - "plan": "randomplan", - } + AssistantState( + messages=[ + HumanMessage(content="Text", id="0"), + VisualizationMessage(answer=self.schema, plan="randomplan", initiator="0", id="1"), + HumanMessage(content="Follow", id="2"), + HumanMessage(content="Up", id="3"), + ], + plan="newrandomplan", + start_id="0", + ) ) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") @@ -134,104 +170,114 @@ def test_agent_reconstructs_conversation_and_merges_messages(self): self.assertEqual(history[2].type, "human") self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) - self.assertIn("Te\nxt", 
history[2].content) + self.assertIn("Text", history[2].content) + def test_preserves_human_in_the_loop_before_initiator(self): node = DummyGeneratorNode(self.team) history = node._construct_messages( - { - "messages": [ - HumanMessage(content="Text"), - VisualizationMessage(answer=self.schema, plan="randomplan"), - HumanMessage(content="Follow"), - HumanMessage(content="Up"), + AssistantState( + messages=[ + HumanMessage(content="Question 1", id="0"), + AssistantMessage(content="Loop", id="1"), + HumanMessage(content="Answer", id="2"), + VisualizationMessage(answer=self.schema, plan="randomplan", initiator="0", id="3"), + HumanMessage(content="Question 2", id="4"), ], - "plan": "newrandomplan", - } + plan="newrandomplan", + start_id="4", + ) ) - - self.assertEqual(len(history), 6) + self.assertEqual(len(history), 8) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("the plan", history[1].content) self.assertNotIn("{{plan}}", history[1].content) self.assertIn("randomplan", history[1].content) - self.assertEqual(history[2].type, "human") - self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) - self.assertIn("Text", history[2].content) + self.assertIn("Question 1", history[2].content) self.assertEqual(history[3].type, "ai") - self.assertEqual(history[3].content, self.schema.model_dump_json()) + self.assertEqual("Loop", history[3].content) self.assertEqual(history[4].type, "human") - self.assertIn("the new plan", history[4].content) - self.assertNotIn("{{plan}}", history[4].content) - self.assertIn("newrandomplan", history[4].content) - self.assertEqual(history[5].type, "human") - self.assertIn("Answer to this question:", history[5].content) - self.assertNotIn("{{question}}", history[5].content) - self.assertIn("Follow\nUp", history[5].content) + self.assertEqual("Answer", history[4].content) + self.assertEqual(history[5].type, "ai") + self.assertEqual(history[6].type, "human") + self.assertIn("the new plan", history[6].content) + self.assertIn("newrandomplan", history[6].content) + self.assertEqual(history[7].type, "human") + self.assertNotIn("{{question}}", history[7].content) + self.assertIn("Question 2", history[7].content) def test_agent_reconstructs_typical_conversation(self): node = DummyGeneratorNode(self.team) history = node._construct_messages( - { - "messages": [ - HumanMessage(content="Question 1"), - RouterMessage(content="trends"), - VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), - AssistantMessage(content="Summary 1"), - HumanMessage(content="Question 2"), - RouterMessage(content="funnel"), - VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), - AssistantMessage(content="Summary 2"), - HumanMessage(content="Question 3"), - RouterMessage(content="funnel"), + AssistantState( + messages=[ + HumanMessage(content="Question 1", id="0"), + RouterMessage(content="trends", id="1"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1", initiator="0", id="2"), + AssistantMessage(content="Summary 1", id="3"), + HumanMessage(content="Question 2", id="4"), + RouterMessage(content="funnel", id="5"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2", initiator="4", id="6"), + AssistantMessage(content="Summary 2", id="7"), + HumanMessage(content="Question 3", id="8"), + RouterMessage(content="funnel", id="9"), ], - "plan": "Plan 3", - } 
+ plan="Plan 3", + start_id="8", + ) ) - self.assertEqual(len(history), 8) + + self.assertEqual(len(history), 10) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("Plan 1", history[1].content) self.assertEqual(history[2].type, "human") self.assertIn("Question 1", history[2].content) - self.assertEqual(history[3].type, "human") - self.assertIn("Plan 2", history[3].content) + self.assertEqual(history[3].type, "ai") + self.assertEqual(history[3].content, "Summary 1") self.assertEqual(history[4].type, "human") - self.assertIn("Question 2", history[4].content) - self.assertEqual(history[5].type, "ai") - self.assertEqual(history[6].type, "human") - self.assertIn("Plan 3", history[6].content) - self.assertEqual(history[7].type, "human") - self.assertIn("Question 3", history[7].content) - - def test_prompt(self): + self.assertIn("Plan 2", history[4].content) + self.assertEqual(history[5].type, "human") + self.assertIn("Question 2", history[5].content) + self.assertEqual(history[6].type, "ai") + self.assertEqual(history[6].content, "Summary 2") + self.assertEqual(history[7].type, "ai") + self.assertEqual(history[8].type, "human") + self.assertIn("Plan 3", history[8].content) + self.assertEqual(history[9].type, "human") + self.assertIn("Question 3", history[9].content) + + def test_prompt_messages_merged(self): node = DummyGeneratorNode(self.team) - state = { - "messages": [ - HumanMessage(content="Question 1"), - RouterMessage(content="trends"), - VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), - AssistantMessage(content="Summary 1"), - HumanMessage(content="Question 2"), - RouterMessage(content="funnel"), - VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), - AssistantMessage(content="Summary 2"), - HumanMessage(content="Question 3"), - RouterMessage(content="funnel"), + state = AssistantState( + messages=[ + HumanMessage(content="Question 1", id="0"), + RouterMessage(content="trends", id="1"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1", initiator="0", id="2"), + AssistantMessage(content="Summary 1", id="3"), + HumanMessage(content="Question 2", id="4"), + RouterMessage(content="funnel", id="5"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2", initiator="4", id="6"), + AssistantMessage(content="Summary 2", id="7"), + HumanMessage(content="Question 3", id="8"), + RouterMessage(content="funnel", id="9"), ], - "plan": "Plan 3", - } + plan="Plan 3", + start_id="8", + ) with patch.object(DummyGeneratorNode, "_model") as generator_model_mock: def assert_prompt(prompt): - self.assertEqual(len(prompt), 4) + self.assertEqual(len(prompt), 6) self.assertEqual(prompt[0].type, "system") self.assertEqual(prompt[1].type, "human") self.assertEqual(prompt[2].type, "ai") self.assertEqual(prompt[3].type, "human") + self.assertEqual(prompt[4].type, "ai") + self.assertEqual(prompt[5].type, "human") generator_model_mock.return_value = RunnableLambda(assert_prompt) node.run(state, {}) @@ -244,19 +290,17 @@ def test_failover_with_incorrect_schema(self): schema["query"] = [] generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) - new_state = node.run({"messages": [HumanMessage(content="Text")]}, {}) - self.assertIn("intermediate_steps", new_state) - self.assertEqual(len(new_state["intermediate_steps"]), 1) + new_state = node.run(AssistantState(messages=[HumanMessage(content="Text")]), 
{}) + self.assertEqual(len(new_state.intermediate_steps), 1) new_state = node.run( - { - "messages": [HumanMessage(content="Text")], - "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], - }, + AssistantState( + messages=[HumanMessage(content="Text")], + intermediate_steps=[(AgentAction(tool="", tool_input="", log="exception"), "exception")], + ), {}, ) - self.assertIn("intermediate_steps", new_state) - self.assertEqual(len(new_state["intermediate_steps"]), 2) + self.assertEqual(len(new_state.intermediate_steps), 2) def test_node_leaves_failover(self): node = DummyGeneratorNode(self.team) @@ -266,25 +310,25 @@ def test_node_leaves_failover(self): return_value=RunnableLambda(lambda _: TestSchema(query=self.schema).model_dump()), ): new_state = node.run( - { - "messages": [HumanMessage(content="Text")], - "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], - }, + AssistantState( + messages=[HumanMessage(content="Text")], + intermediate_steps=[(AgentAction(tool="", tool_input="", log="exception"), "exception")], + ), {}, ) - self.assertIsNone(new_state["intermediate_steps"]) + self.assertIsNone(new_state.intermediate_steps) new_state = node.run( - { - "messages": [HumanMessage(content="Text")], - "intermediate_steps": [ + AssistantState( + messages=[HumanMessage(content="Text")], + intermediate_steps=[ (AgentAction(tool="", tool_input="", log="exception"), "exception"), (AgentAction(tool="", tool_input="", log="exception"), "exception"), ], - }, + ), {}, ) - self.assertIsNone(new_state["intermediate_steps"]) + self.assertIsNone(new_state.intermediate_steps) def test_node_leaves_failover_after_second_unsuccessful_attempt(self): node = DummyGeneratorNode(self.team) @@ -295,29 +339,30 @@ def test_node_leaves_failover_after_second_unsuccessful_attempt(self): generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) new_state = node.run( - { - "messages": [HumanMessage(content="Text")], - "intermediate_steps": [ + AssistantState( + messages=[HumanMessage(content="Text")], + intermediate_steps=[ (AgentAction(tool="", tool_input="", log="exception"), "exception"), (AgentAction(tool="", tool_input="", log="exception"), "exception"), ], - }, + ), {}, ) - self.assertIsNone(new_state["intermediate_steps"]) - self.assertEqual(len(new_state["messages"]), 1) - self.assertIsInstance(new_state["messages"][0], FailureMessage) + self.assertIsNone(new_state.intermediate_steps) + self.assertEqual(len(new_state.messages), 1) + self.assertIsInstance(new_state.messages[0], FailureMessage) def test_agent_reconstructs_conversation_with_failover(self): action = AgentAction(tool="fix", tool_input="validation error", log="exception") node = DummyGeneratorNode(self.team) history = node._construct_messages( - { - "messages": [HumanMessage(content="Text")], - "plan": "randomplan", - "intermediate_steps": [(action, "uniqexception")], - }, - "uniqexception", + AssistantState( + messages=[HumanMessage(content="Text", id="0")], + plan="randomplan", + intermediate_steps=[(action, "uniqexception")], + start_id="0", + ), + validation_error_message="uniqexception", ) self.assertEqual(len(history), 4) self.assertEqual(history[0].type, "human") @@ -337,14 +382,14 @@ def test_agent_reconstructs_conversation_with_failover(self): def test_agent_reconstructs_conversation_with_failed_messages(self): node = DummyGeneratorNode(self.team) history = node._construct_messages( - { - "messages": [ + AssistantState( + messages=[ 
HumanMessage(content="Text"), FailureMessage(content="Error"), HumanMessage(content="Text"), ], - "plan": "randomplan", - }, + plan="randomplan", + ), ) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") @@ -360,19 +405,19 @@ def test_agent_reconstructs_conversation_with_failed_messages(self): def test_router(self): node = DummyGeneratorNode(self.team) - state = node.router({"messages": [], "intermediate_steps": None}) + state = node.router(AssistantState(messages=[], intermediate_steps=None)) self.assertEqual(state, "next") state = node.router( - {"messages": [], "intermediate_steps": [(AgentAction(tool="", tool_input="", log=""), None)]} + AssistantState(messages=[], intermediate_steps=[(AgentAction(tool="", tool_input="", log=""), None)]) ) self.assertEqual(state, "tools") -class TestSchemaGeneratorToolsNode(ClickhouseTestMixin, APIBaseTest): +class TestSchemaGeneratorToolsNode(BaseTest): def test_tools_node(self): node = SchemaGeneratorToolsNode(self.team) action = AgentAction(tool="fix", tool_input="validationerror", log="pydanticexception") - state = node.run({"messages": [], "intermediate_steps": [(action, None)]}, {}) - self.assertIsNotNone("validationerror", state["intermediate_steps"][0][1]) - self.assertIn("validationerror", state["intermediate_steps"][0][1]) - self.assertIn("pydanticexception", state["intermediate_steps"][0][1]) + state = node.run(AssistantState(messages=[], intermediate_steps=[(action, None)]), {}) + self.assertIsNotNone("validationerror", state.intermediate_steps[0][1]) + self.assertIn("validationerror", state.intermediate_steps[0][1]) + self.assertIn("pydanticexception", state.intermediate_steps[0][1]) diff --git a/ee/hogai/summarizer/nodes.py b/ee/hogai/summarizer/nodes.py index 8d5e8a406f45e..513246bcc1238 100644 --- a/ee/hogai/summarizer/nodes.py +++ b/ee/hogai/summarizer/nodes.py @@ -1,15 +1,18 @@ import json from time import sleep +from uuid import uuid4 + from django.conf import settings +from django.core.serializers.json import DjangoJSONEncoder from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI -from django.core.serializers.json import DjangoJSONEncoder from rest_framework.exceptions import APIException from sentry_sdk import capture_exception -from ee.hogai.summarizer.prompts import SUMMARIZER_SYSTEM_PROMPT, SUMMARIZER_INSTRUCTION_PROMPT -from ee.hogai.utils import AssistantNode, AssistantNodeName, AssistantState +from ee.hogai.summarizer.prompts import SUMMARIZER_INSTRUCTION_PROMPT, SUMMARIZER_SYSTEM_PROMPT +from ee.hogai.utils.nodes import AssistantNode +from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState from posthog.api.services.query import process_query_dict from posthog.clickhouse.client.execute_async import get_query_status from posthog.errors import ExposedCHQueryError @@ -21,8 +24,8 @@ class SummarizerNode(AssistantNode): name = AssistantNodeName.SUMMARIZER - def run(self, state: AssistantState, config: RunnableConfig): - viz_message = state["messages"][-1] + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + viz_message = state.messages[-1] if not isinstance(viz_message, VisualizationMessage): raise ValueError("Can only run summarization with a visualization message as the last one in the state") if viz_message.answer is None: @@ -58,10 +61,16 @@ def run(self, state: AssistantState, config: RunnableConfig): err_message = ", ".join(f"{key}: 
{value}" for key, value in err.detail.items()) elif isinstance(err.detail, list): err_message = ", ".join(map(str, err.detail)) - return {"messages": [FailureMessage(content=f"There was an error running this query: {err_message}")]} + return PartialAssistantState( + messages=[ + FailureMessage(content=f"There was an error running this query: {err_message}", id=str(uuid4())) + ] + ) except Exception as err: capture_exception(err) - return {"messages": [FailureMessage(content="There was an unknown error running this query.")]} + return PartialAssistantState( + messages=[FailureMessage(content="There was an unknown error running this query.", id=str(uuid4()))] + ) summarization_prompt = ChatPromptTemplate(self._construct_messages(state), template_format="mustache") @@ -76,7 +85,7 @@ def run(self, state: AssistantState, config: RunnableConfig): config, ) - return {"messages": [AssistantMessage(content=str(message.content), done=True)]} + return PartialAssistantState(messages=[AssistantMessage(content=str(message.content), id=str(uuid4()))]) @property def _model(self): @@ -85,7 +94,7 @@ def _model(self): def _construct_messages(self, state: AssistantState) -> list[tuple[str, str]]: conversation: list[tuple[str, str]] = [("system", SUMMARIZER_SYSTEM_PROMPT)] - for message in state.get("messages", []): + for message in state.messages: if isinstance(message, HumanMessage): conversation.append(("human", message.content)) elif isinstance(message, AssistantMessage): diff --git a/ee/hogai/summarizer/test/test_nodes.py b/ee/hogai/summarizer/test/test_nodes.py index b38d88275aa19..9c54517717b5f 100644 --- a/ee/hogai/summarizer/test/test_nodes.py +++ b/ee/hogai/summarizer/test/test_nodes.py @@ -1,23 +1,23 @@ from unittest.mock import patch from django.test import override_settings -from langchain_core.runnables import RunnableLambda from langchain_core.messages import ( HumanMessage as LangchainHumanMessage, ) +from langchain_core.runnables import RunnableLambda +from rest_framework.exceptions import ValidationError + from ee.hogai.summarizer.nodes import SummarizerNode from ee.hogai.summarizer.prompts import SUMMARIZER_INSTRUCTION_PROMPT, SUMMARIZER_SYSTEM_PROMPT +from ee.hogai.utils.types import AssistantState +from posthog.api.services.query import process_query_dict from posthog.schema import ( - AssistantMessage, AssistantTrendsEventsNode, AssistantTrendsQuery, - FailureMessage, HumanMessage, VisualizationMessage, ) -from rest_framework.exceptions import ValidationError from posthog.test.base import APIBaseTest, ClickhouseTestMixin -from posthog.api.services.query import process_query_dict @override_settings(IN_UNIT_TESTING=True) @@ -32,28 +32,26 @@ def test_node_runs(self, mock_process_query_dict): lambda _: LangchainHumanMessage(content="The results indicate foobar.") ) new_state = node.run( - { - "messages": [ - HumanMessage(content="Text"), + AssistantState( + messages=[ + HumanMessage(content="Text", id="test"), VisualizationMessage( answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]), plan="Plan", - done=True, + id="test2", + initiator="test", ), ], - "plan": "Plan", - }, + plan="Plan", + start_id="test", + ), {}, ) mock_process_query_dict.assert_called_once() # Query processing started - self.assertEqual( - new_state, - { - "messages": [ - AssistantMessage(content="The results indicate foobar.", done=True), - ], - }, - ) + msg = new_state.messages[0] + self.assertEqual(msg.content, "The results indicate foobar.") + self.assertEqual(msg.type, "ai") + self.assertIsNotNone(msg.id) 
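
The node and test changes above and below all follow the same pattern: nodes read typed attributes from AssistantState (state.messages, state.plan) instead of dict keys, and return a PartialAssistantState containing only the fields they changed, which is why the tests now assert directly on new_state.messages and new_state.intermediate_steps. A minimal sketch of that pattern follows, using stand-in pydantic models; Message, State, PartialState, and summarize_node are illustrative names only, not the real ee.hogai types or merge behaviour.

from typing import Optional
from uuid import uuid4
from pydantic import BaseModel, Field

# Stand-ins for the real AssistantState / PartialAssistantState (ee/hogai/utils/types.py).
class Message(BaseModel):
    content: str
    id: Optional[str] = None

class State(BaseModel):
    messages: list[Message] = Field(default_factory=list)
    plan: Optional[str] = None

class PartialState(BaseModel):
    messages: Optional[list[Message]] = None
    plan: Optional[str] = None

def summarize_node(state: State) -> PartialState:
    # Read typed attributes (state.messages) rather than dict keys (state["messages"]),
    # and return only the fields this node actually updates.
    last = state.messages[-1].content if state.messages else ""
    return PartialState(messages=[Message(content=f"Summary of: {last}", id=str(uuid4()))])

# Usage mirroring the updated tests: assert on attributes of the returned partial state,
# checking content and type while only requiring that the generated id is present.
update = summarize_node(State(messages=[Message(content="What's the trend in signups?", id="0")]))
assert update.messages is not None
assert update.messages[0].content.startswith("Summary of:")
assert update.messages[0].id is not None

Because message ids are generated with uuid4 at runtime, the tests compare content and type and merely check that an id exists, instead of comparing whole state dicts as before.
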
@patch( "ee.hogai.summarizer.nodes.process_query_dict", @@ -66,28 +64,26 @@ def test_node_handles_internal_error(self, mock_process_query_dict): lambda _: LangchainHumanMessage(content="The results indicate foobar.") ) new_state = node.run( - { - "messages": [ - HumanMessage(content="Text"), + AssistantState( + messages=[ + HumanMessage(content="Text", id="test"), VisualizationMessage( answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]), plan="Plan", - done=True, + id="test2", + initiator="test", ), ], - "plan": "Plan", - }, + plan="Plan", + start_id="test", + ), {}, ) mock_process_query_dict.assert_called_once() # Query processing started - self.assertEqual( - new_state, - { - "messages": [ - FailureMessage(content="There was an unknown error running this query."), - ], - }, - ) + msg = new_state.messages[0] + self.assertEqual(msg.content, "There was an unknown error running this query.") + self.assertEqual(msg.type, "ai/failure") + self.assertIsNotNone(msg.id) @patch( "ee.hogai.summarizer.nodes.process_query_dict", @@ -102,33 +98,29 @@ def test_node_handles_exposed_error(self, mock_process_query_dict): lambda _: LangchainHumanMessage(content="The results indicate foobar.") ) new_state = node.run( - { - "messages": [ - HumanMessage(content="Text"), + AssistantState( + messages=[ + HumanMessage(content="Text", id="test"), VisualizationMessage( answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]), plan="Plan", - done=True, + id="test2", + initiator="test", ), ], - "plan": "Plan", - }, + plan="Plan", + start_id="test", + ), {}, ) mock_process_query_dict.assert_called_once() # Query processing started + msg = new_state.messages[0] self.assertEqual( - new_state, - { - "messages": [ - FailureMessage( - content=( - "There was an error running this query: This query exceeds the capabilities of our picolator. " - "Try de-brolling its flim-flam." - ) - ), - ], - }, + msg.content, + "There was an error running this query: This query exceeds the capabilities of our picolator. 
Try de-brolling its flim-flam.", ) + self.assertEqual(msg.type, "ai/failure") + self.assertIsNotNone(msg.id) def test_node_requires_a_viz_message_in_state(self): node = SummarizerNode(self.team) @@ -137,12 +129,13 @@ def test_node_requires_a_viz_message_in_state(self): ValueError, "Can only run summarization with a visualization message as the last one in the state" ): node.run( - { - "messages": [ + AssistantState( + messages=[ HumanMessage(content="Text"), ], - "plan": "Plan", - }, + plan="Plan", + start_id="test", + ), {}, ) @@ -151,16 +144,13 @@ def test_node_requires_viz_message_in_state_to_have_query(self): with self.assertRaisesMessage(ValueError, "Did not found query in the visualization message"): node.run( - { - "messages": [ - VisualizationMessage( - answer=None, - plan="Plan", - done=True, - ), + AssistantState( + messages=[ + VisualizationMessage(answer=None, plan="Plan", id="test"), ], - "plan": "Plan", - }, + plan="Plan", + start_id="test", + ), {}, ) @@ -170,16 +160,18 @@ def test_agent_reconstructs_conversation(self): node = SummarizerNode(self.team) history = node._construct_messages( - { - "messages": [ - HumanMessage(content="What's the trends in signups?"), + AssistantState( + messages=[ + HumanMessage(content="What's the trends in signups?", id="test"), VisualizationMessage( answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]), plan="Plan", - done=True, + id="test2", + initiator="test", ), - ] - } + ], + start_id="test", + ) ) self.assertEqual( history, diff --git a/ee/hogai/taxonomy_agent/nodes.py b/ee/hogai/taxonomy_agent/nodes.py index 025058a51eec1..bd26a7a93918f 100644 --- a/ee/hogai/taxonomy_agent/nodes.py +++ b/ee/hogai/taxonomy_agent/nodes.py @@ -1,4 +1,3 @@ -import itertools import xml.etree.ElementTree as ET from abc import ABC from functools import cached_property @@ -7,10 +6,16 @@ from git import Optional from langchain.agents.format_scratchpad import format_log_to_str from langchain_core.agents import AgentAction -from langchain_core.messages import AIMessage as LangchainAssistantMessage, BaseMessage, merge_message_runs +from langchain_core.messages import ( + AIMessage as LangchainAssistantMessage, + BaseMessage, + HumanMessage as LangchainHumanMessage, + merge_message_runs, +) from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI +from langgraph.errors import NodeInterrupt from pydantic import ValidationError from ee.hogai.taxonomy import CORE_FILTER_DEFINITIONS_BY_GROUP @@ -24,6 +29,7 @@ REACT_FOLLOW_UP_PROMPT, REACT_FORMAT_PROMPT, REACT_FORMAT_REMINDER_PROMPT, + REACT_HUMAN_IN_THE_LOOP_PROMPT, REACT_MALFORMED_JSON_PROMPT, REACT_MISSING_ACTION_CORRECTION_PROMPT, REACT_MISSING_ACTION_PROMPT, @@ -33,13 +39,18 @@ REACT_USER_PROMPT, ) from ee.hogai.taxonomy_agent.toolkit import TaxonomyAgentTool, TaxonomyAgentToolkit -from ee.hogai.utils import AssistantNode, AssistantState, filter_visualization_conversation, remove_line_breaks +from ee.hogai.utils.helpers import filter_messages, remove_line_breaks, slice_messages_to_conversation_start +from ee.hogai.utils.nodes import AssistantNode +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner from posthog.hogql_queries.query_runner import ExecutionMode from posthog.models.group_type_mapping import GroupTypeMapping from posthog.schema import ( + AssistantMessage, 
CachedTeamTaxonomyQueryResponse, + HumanMessage, TeamTaxonomyQuery, + VisualizationMessage, ) @@ -50,8 +61,8 @@ def _run_with_prompt_and_toolkit( prompt: ChatPromptTemplate, toolkit: TaxonomyAgentToolkit, config: Optional[RunnableConfig] = None, - ) -> AssistantState: - intermediate_steps = state.get("intermediate_steps") or [] + ) -> PartialAssistantState: + intermediate_steps = state.intermediate_steps or [] conversation = ( prompt + ChatPromptTemplate.from_messages( @@ -79,6 +90,7 @@ def _run_with_prompt_and_toolkit( "react_format": self._get_react_format_prompt(toolkit), "react_format_reminder": REACT_FORMAT_REMINDER_PROMPT, "react_property_filters": self._get_react_property_filters_prompt(), + "react_human_in_the_loop": REACT_HUMAN_IN_THE_LOOP_PROMPT, "product_description": self._team.project.product_description, "groups": self._team_group_types, "events": self._events_prompt, @@ -108,12 +120,12 @@ def _run_with_prompt_and_toolkit( e.llm_output, ) - return { - "intermediate_steps": [*intermediate_steps, (result, None)], - } + return PartialAssistantState( + intermediate_steps=[*intermediate_steps, (result, None)], + ) def router(self, state: AssistantState): - if state.get("intermediate_steps", []): + if state.intermediate_steps: return "tools" raise ValueError("Invalid state.") @@ -188,33 +200,34 @@ def _construct_messages(self, state: AssistantState) -> list[BaseMessage]: """ Reconstruct the conversation for the agent. On this step we only care about previously asked questions and generated plans. All other messages are filtered out. """ - human_messages, visualization_messages = filter_visualization_conversation(state.get("messages", [])) - - if not human_messages: - return [] - + start_id = state.start_id + filtered_messages = filter_messages(slice_messages_to_conversation_start(state.messages, start_id)) conversation = [] - for idx, messages in enumerate(itertools.zip_longest(human_messages, visualization_messages)): - human_message, viz_message = messages - - if human_message: + for idx, message in enumerate(filtered_messages): + if isinstance(message, HumanMessage): + # Add initial instructions. if idx == 0: conversation.append( HumanMessagePromptTemplate.from_template(REACT_USER_PROMPT, template_format="mustache").format( - question=human_message.content + question=message.content ) ) - else: + # Add follow-up instructions only for the human message that initiated a generation. + elif message.id == start_id: conversation.append( HumanMessagePromptTemplate.from_template( REACT_FOLLOW_UP_PROMPT, template_format="mustache", - ).format(feedback=human_message.content) + ).format(feedback=message.content) ) - - if viz_message: - conversation.append(LangchainAssistantMessage(content=viz_message.plan or "")) + # Everything else leave as is. 
+ else: + conversation.append(LangchainHumanMessage(content=message.content)) + elif isinstance(message, VisualizationMessage): + conversation.append(LangchainAssistantMessage(content=message.plan or "")) + elif isinstance(message, AssistantMessage): + conversation.append(LangchainAssistantMessage(content=message.content)) return conversation @@ -230,26 +243,37 @@ def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]] class TaxonomyAgentPlannerToolsNode(AssistantNode, ABC): def _run_with_toolkit( self, state: AssistantState, toolkit: TaxonomyAgentToolkit, config: Optional[RunnableConfig] = None - ) -> AssistantState: - intermediate_steps = state.get("intermediate_steps") or [] - action, _ = intermediate_steps[-1] + ) -> PartialAssistantState: + intermediate_steps = state.intermediate_steps or [] + action, observation = intermediate_steps[-1] try: input = TaxonomyAgentTool.model_validate({"name": action.tool, "arguments": action.tool_input}).root except ValidationError as e: - observation = ( + observation = str( ChatPromptTemplate.from_template(REACT_PYDANTIC_VALIDATION_EXCEPTION_PROMPT, template_format="mustache") .format_messages(exception=e.errors(include_url=False))[0] .content ) - return {"intermediate_steps": [*intermediate_steps[:-1], (action, str(observation))]} + return PartialAssistantState( + intermediate_steps=[*intermediate_steps[:-1], (action, str(observation))], + ) # The plan has been found. Move to the generation. if input.name == "final_answer": - return { - "plan": input.arguments, - "intermediate_steps": None, - } + return PartialAssistantState( + plan=input.arguments, + intermediate_steps=[], + ) + if input.name == "ask_user_for_help": + # The agent has requested help, so we interrupt the graph. + if not observation: + raise NodeInterrupt(input.arguments) + + # Feedback was provided. + return PartialAssistantState( + intermediate_steps=[*intermediate_steps[:-1], (action, observation)], + ) output = "" if input.name == "retrieve_event_properties": @@ -263,9 +287,11 @@ def _run_with_toolkit( else: output = toolkit.handle_incorrect_response(input.arguments) - return {"intermediate_steps": [*intermediate_steps[:-1], (action, output)]} + return PartialAssistantState( + intermediate_steps=[*intermediate_steps[:-1], (action, output)], + ) def router(self, state: AssistantState): - if state.get("plan") is not None: + if state.plan is not None: return "plan_found" return "continue" diff --git a/ee/hogai/taxonomy_agent/prompts.py b/ee/hogai/taxonomy_agent/prompts.py index f63a7dfe15455..c9d409bcdf103 100644 --- a/ee/hogai/taxonomy_agent/prompts.py +++ b/ee/hogai/taxonomy_agent/prompts.py @@ -81,6 +81,15 @@ """.strip() +REACT_HUMAN_IN_THE_LOOP_PROMPT = """ + +Ask the user for clarification if: +- The user's question is ambiguous. +- You can't find matching events or properties. +- You're unable to build a plan that effectively answers the user's question. + +""".strip() + REACT_FORMAT_REMINDER_PROMPT = """ Begin! Reminder that you must ALWAYS respond with a valid JSON blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB``` then Observation. 
""".strip() diff --git a/ee/hogai/taxonomy_agent/test/test_nodes.py b/ee/hogai/taxonomy_agent/test/test_nodes.py index 40127c19370b6..cb25331664331 100644 --- a/ee/hogai/taxonomy_agent/test/test_nodes.py +++ b/ee/hogai/taxonomy_agent/test/test_nodes.py @@ -11,7 +11,7 @@ TaxonomyAgentPlannerToolsNode, ) from ee.hogai.taxonomy_agent.toolkit import TaxonomyAgentToolkit, ToolkitTool -from ee.hogai.utils import AssistantState +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.models import GroupTypeMapping from posthog.schema import ( AssistantMessage, @@ -37,7 +37,7 @@ def setUp(self): def _get_node(self): class Node(TaxonomyAgentPlannerNode): - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: prompt: ChatPromptTemplate = ChatPromptTemplate.from_messages([("user", "test")]) toolkit = DummyToolkit(self._team) return super()._run_with_prompt_and_toolkit(state, prompt, toolkit, config=config) @@ -46,19 +46,20 @@ def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: def test_agent_reconstructs_conversation(self): node = self._get_node() - history = node._construct_messages({"messages": [HumanMessage(content="Text")]}) + history = node._construct_messages(AssistantState(messages=[HumanMessage(content="Text")])) self.assertEqual(len(history), 1) self.assertEqual(history[0].type, "human") self.assertIn("Text", history[0].content) self.assertNotIn(f"{{question}}", history[0].content) history = node._construct_messages( - { - "messages": [ - HumanMessage(content="Text"), - VisualizationMessage(answer=self.schema, plan="randomplan"), - ] - } + AssistantState( + messages=[ + HumanMessage(content="Text", id="0"), + VisualizationMessage(answer=self.schema, plan="randomplan", id="1", initiator="0"), + ], + start_id="1", + ) ) self.assertEqual(len(history), 2) self.assertEqual(history[0].type, "human") @@ -68,13 +69,14 @@ def test_agent_reconstructs_conversation(self): self.assertEqual(history[1].content, "randomplan") history = node._construct_messages( - { - "messages": [ - HumanMessage(content="Text"), - VisualizationMessage(answer=self.schema, plan="randomplan"), - HumanMessage(content="Text"), - ] - } + AssistantState( + messages=[ + HumanMessage(content="Text", id="0"), + VisualizationMessage(answer=self.schema, plan="randomplan", id="1", initiator="0"), + HumanMessage(content="Text", id="2"), + ], + start_id="2", + ) ) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") @@ -89,12 +91,14 @@ def test_agent_reconstructs_conversation(self): def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): node = self._get_node() history = node._construct_messages( - { - "messages": [ - HumanMessage(content="Text"), - AssistantMessage(content="test"), - ] - } + AssistantState( + messages=[ + HumanMessage(content="Text", id="0"), + RouterMessage(content="trends", id="1"), + AssistantMessage(content="test", id="2"), + ], + start_id="0", + ) ) self.assertEqual(len(history), 1) self.assertEqual(history[0].type, "human") @@ -104,13 +108,13 @@ def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): def test_agent_reconstructs_conversation_with_failures(self): node = self._get_node() history = node._construct_messages( - { - "messages": [ + AssistantState( + messages=[ HumanMessage(content="Text"), FailureMessage(content="Error"), HumanMessage(content="Text"), - ] - } + ], + ) ) 
self.assertEqual(len(history), 1) self.assertEqual(history[0].type, "human") @@ -120,32 +124,60 @@ def test_agent_reconstructs_conversation_with_failures(self): def test_agent_reconstructs_typical_conversation(self): node = self._get_node() history = node._construct_messages( - { - "messages": [ - HumanMessage(content="Question 1"), - RouterMessage(content="trends"), - VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), - AssistantMessage(content="Summary 1"), - HumanMessage(content="Question 2"), - RouterMessage(content="funnel"), - VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), - AssistantMessage(content="Summary 2"), - HumanMessage(content="Question 3"), - RouterMessage(content="funnel"), - ] - } + AssistantState( + messages=[ + HumanMessage(content="Question 1", id="0"), + RouterMessage(content="trends", id="1"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1", id="2", initiator="0"), + AssistantMessage(content="Summary 1", id="3"), + HumanMessage(content="Question 2", id="4"), + RouterMessage(content="funnel", id="5"), + AssistantMessage(content="Loop 1", id="6"), + HumanMessage(content="Loop Answer 1", id="7"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2", id="8", initiator="4"), + AssistantMessage(content="Summary 2", id="9"), + HumanMessage(content="Question 3", id="10"), + RouterMessage(content="funnel", id="11"), + ], + start_id="10", + ) ) - self.assertEqual(len(history), 5) + self.assertEqual(len(history), 9) self.assertEqual(history[0].type, "human") self.assertIn("Question 1", history[0].content) self.assertEqual(history[1].type, "ai") self.assertEqual(history[1].content, "Plan 1") - self.assertEqual(history[2].type, "human") - self.assertIn("Question 2", history[2].content) - self.assertEqual(history[3].type, "ai") - self.assertEqual(history[3].content, "Plan 2") - self.assertEqual(history[4].type, "human") - self.assertIn("Question 3", history[4].content) + self.assertEqual(history[2].type, "ai") + self.assertEqual(history[2].content, "Summary 1") + self.assertEqual(history[3].type, "human") + self.assertIn("Question 2", history[3].content) + self.assertEqual(history[4].type, "ai") + self.assertEqual(history[4].content, "Loop 1") + self.assertEqual(history[5].type, "human") + self.assertEqual(history[5].content, "Loop Answer 1") + self.assertEqual(history[6].content, "Plan 2") + self.assertEqual(history[6].type, "ai") + self.assertEqual(history[7].type, "ai") + self.assertEqual(history[7].content, "Summary 2") + self.assertEqual(history[8].type, "human") + self.assertIn("Question 3", history[8].content) + + def test_agent_reconstructs_conversation_without_messages_after_parent(self): + node = self._get_node() + history = node._construct_messages( + AssistantState( + messages=[ + HumanMessage(content="Question 1", id="0"), + RouterMessage(content="trends", id="1"), + AssistantMessage(content="Loop 1", id="2"), + HumanMessage(content="Loop Answer 1", id="3"), + ], + start_id="0", + ) + ) + self.assertEqual(len(history), 1) + self.assertEqual(history[0].type, "human") + self.assertIn("Question 1", history[0].content) def test_agent_filters_out_low_count_events(self): _create_person(distinct_ids=["test"], team=self.team) @@ -182,9 +214,9 @@ def test_agent_handles_output_without_action_block(self): return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")), ): node = self._get_node() - state_update = node.run({"messages": 
[HumanMessage(content="Question")]}, {}) - self.assertEqual(len(state_update["intermediate_steps"]), 1) - action, obs = state_update["intermediate_steps"][0] + state_update = node.run(AssistantState(messages=[HumanMessage(content="Question")]), {}) + self.assertEqual(len(state_update.intermediate_steps), 1) + action, obs = state_update.intermediate_steps[0] self.assertIsNone(obs) self.assertIn("I don't want to output an action.", action.log) self.assertIn("Action:", action.log) @@ -196,9 +228,9 @@ def test_agent_handles_output_with_malformed_json(self): return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")), ): node = self._get_node() - state_update = node.run({"messages": [HumanMessage(content="Question")]}, {}) - self.assertEqual(len(state_update["intermediate_steps"]), 1) - action, obs = state_update["intermediate_steps"][0] + state_update = node.run(AssistantState(messages=[HumanMessage(content="Question")]), {}) + self.assertEqual(len(state_update.intermediate_steps), 1) + action, obs = state_update.intermediate_steps[0] self.assertIsNone(obs) self.assertIn("Thought.\nAction: abc", action.log) self.assertIn("action", action.tool_input) @@ -232,34 +264,34 @@ def test_property_filters_prompt(self): class TestTaxonomyAgentPlannerToolsNode(ClickhouseTestMixin, APIBaseTest): def _get_node(self): class Node(TaxonomyAgentPlannerToolsNode): - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: toolkit = DummyToolkit(self._team) return super()._run_with_toolkit(state, toolkit, config=config) return Node(self.team) def test_node_handles_action_name_validation_error(self): - state = { - "intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")], - "messages": [], - } + state = AssistantState( + intermediate_steps=[(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")], + messages=[], + ) node = self._get_node() state_update = node.run(state, {}) - self.assertEqual(len(state_update["intermediate_steps"]), 1) - action, observation = state_update["intermediate_steps"][0] + self.assertEqual(len(state_update.intermediate_steps), 1) + action, observation = state_update.intermediate_steps[0] self.assertIsNotNone(observation) self.assertIn("", observation) def test_node_handles_action_input_validation_error(self): - state = { - "intermediate_steps": [ + state = AssistantState( + intermediate_steps=[ (AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test") ], - "messages": [], - } + messages=[], + ) node = self._get_node() state_update = node.run(state, {}) - self.assertEqual(len(state_update["intermediate_steps"]), 1) - action, observation = state_update["intermediate_steps"][0] + self.assertEqual(len(state_update.intermediate_steps), 1) + action, observation = state_update.intermediate_steps[0] self.assertIsNotNone(observation) self.assertIn("", observation) diff --git a/ee/hogai/taxonomy_agent/toolkit.py b/ee/hogai/taxonomy_agent/toolkit.py index dc8a0e092c2e6..d05b6f0c933ef 100644 --- a/ee/hogai/taxonomy_agent/toolkit.py +++ b/ee/hogai/taxonomy_agent/toolkit.py @@ -55,6 +55,7 @@ class SingleArgumentTaxonomyAgentTool(BaseModel): "retrieve_event_properties", "final_answer", "handle_incorrect_response", + "ask_user_for_help", ] arguments: str @@ -145,6 +146,16 @@ def _default_tools(self) -> list[ToolkitTool]: property_name: The name of the property that 
you want to retrieve values for. """, }, + { + "name": "ask_user_for_help", + "signature": "(question: str)", + "description": """ + Use this tool to ask a question to the user. Your question must be concise and clear. + + Args: + question: The question you want to ask. + """, + }, ] def render_text_description(self) -> str: diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py index b6cd65bd4ec12..6d0bb8807d629 100644 --- a/ee/hogai/test/test_assistant.py +++ b/ee/hogai/test/test_assistant.py @@ -1,31 +1,63 @@ import json -from typing import Any +from typing import Any, Optional, cast from unittest.mock import patch -from uuid import uuid4 -from ee.hogai.utils import Conversation -from posthog.schema import AssistantMessage, HumanMessage -from ..assistant import Assistant + +from langchain_core import messages +from langchain_core.agents import AgentAction +from langchain_core.runnables import RunnableConfig, RunnableLambda from langgraph.graph.state import CompiledStateGraph +from langgraph.types import StateSnapshot +from pydantic import BaseModel + +from ee.models.assistant import Conversation +from posthog.schema import AssistantMessage, HumanMessage, ReasoningMessage +from posthog.test.base import NonAtomicBaseTest + +from ..assistant import Assistant from ..graph import AssistantGraph, AssistantNodeName -from posthog.test.base import BaseTest -from langchain_core.agents import AgentAction -class TestAssistant(BaseTest): - def _run_assistant_graph(self, test_graph: CompiledStateGraph) -> list[tuple[str, Any]]: +class TestAssistant(NonAtomicBaseTest): + CLASS_DATA_LEVEL_SETUP = False + + def setUp(self): + super().setUp() + self.conversation = Conversation.objects.create(team=self.team, user=self.user) + + def _run_assistant_graph( + self, + test_graph: Optional[CompiledStateGraph] = None, + message: Optional[str] = "Hello", + conversation: Optional[Conversation] = None, + is_new_conversation: bool = False, + ) -> list[tuple[str, Any]]: # Create assistant instance with our test graph assistant = Assistant( - team=self.team, - conversation=Conversation(messages=[HumanMessage(content="Hello")], session_id=str(uuid4())), + self.team, + conversation or self.conversation, + HumanMessage(content=message), + self.user, + is_new_conversation=is_new_conversation, ) - assistant._graph = test_graph + if test_graph: + assistant._graph = test_graph # Capture and parse output of assistant.stream() output: list[tuple[str, Any]] = [] for message in assistant.stream(): - event_line, data_line, *_ = message.split("\n") + event_line, data_line, *_ = cast(str, message).split("\n") output.append((event_line.removeprefix("event: "), json.loads(data_line.removeprefix("data: ")))) return output + def assertConversationEqual(self, output: list[tuple[str, Any]], expected_output: list[tuple[str, Any]]): + for i, ((output_msg_type, output_msg), (expected_msg_type, expected_msg)) in enumerate( + zip(output, expected_output) + ): + self.assertEqual(output_msg_type, expected_msg_type, f"Message type mismatch at index {i}") + msg_dict = ( + expected_msg.model_dump(exclude_none=True) if isinstance(expected_msg, BaseModel) else expected_msg + ) + self.assertDictContainsSubset(msg_dict, output_msg, f"Message content mismatch at index {i}") + @patch( "ee.hogai.trends.nodes.TrendsPlannerNode.run", return_value={"intermediate_steps": [(AgentAction(tool="final_answer", tool_input="", log=""), None)]}, @@ -39,19 +71,22 @@ def test_reasoning_messages_added(self, _mock_summarizer_run, 
_mock_funnel_plann .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) .add_trends_planner(AssistantNodeName.SUMMARIZER) .add_summarizer(AssistantNodeName.END) - .compile() + .compile(), + conversation=self.conversation, ) # Assert that ReasoningMessages are added - assert output == [ - ("status", {"type": "ack"}), + expected_output = [ + ( + "message", + HumanMessage(content="Hello").model_dump(exclude_none=True), + ), ( "message", { "type": "ai/reasoning", "content": "Picking relevant events and properties", # For TrendsPlannerNode "substeps": [], - "done": True, }, ), ( @@ -60,7 +95,6 @@ def test_reasoning_messages_added(self, _mock_summarizer_run, _mock_funnel_plann "type": "ai/reasoning", "content": "Picking relevant events and properties", # For TrendsPlannerToolsNode "substeps": [], - "done": True, }, ), ( @@ -71,6 +105,7 @@ def test_reasoning_messages_added(self, _mock_summarizer_run, _mock_funnel_plann }, ), ] + self.assertConversationEqual(output, expected_output) @patch( "ee.hogai.trends.nodes.TrendsPlannerNode.run", @@ -105,19 +140,22 @@ def test_reasoning_messages_with_substeps_added(self, _mock_funnel_planner_run): AssistantGraph(self.team) .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) .add_trends_planner(AssistantNodeName.END) - .compile() + .compile(), + conversation=self.conversation, ) # Assert that ReasoningMessages are added - assert output == [ - ("status", {"type": "ack"}), + expected_output = [ + ( + "message", + HumanMessage(content="Hello").model_dump(exclude_none=True), + ), ( "message", { "type": "ai/reasoning", "content": "Picking relevant events and properties", # For TrendsPlannerNode "substeps": [], - "done": True, }, ), ( @@ -131,7 +169,153 @@ def test_reasoning_messages_with_substeps_added(self, _mock_funnel_planner_run): "Analyzing `currency` event's property `purchase`", "Analyzing person property `country_of_birth`", ], - "done": True, }, ), ] + self.assertConversationEqual(output, expected_output) + + def _test_human_in_the_loop(self, graph: CompiledStateGraph): + with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock: + config: RunnableConfig = { + "configurable": { + "thread_id": self.conversation.id, + } + } + + # Interrupt the graph + message = """ + Thought: Let's ask for help. + Action: + ``` + { + "action": "ask_user_for_help", + "action_input": "Need help with this query" + } + ``` + """ + mock.return_value = RunnableLambda(lambda _: messages.AIMessage(content=message)) + output = self._run_assistant_graph(graph, conversation=self.conversation) + expected_output = [ + ("message", HumanMessage(content="Hello")), + ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])), + ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])), + ("message", AssistantMessage(content="Need help with this query")), + ] + self.assertConversationEqual(output, expected_output) + snapshot: StateSnapshot = graph.get_state(config) + self.assertTrue(snapshot.next) + self.assertIn("intermediate_steps", snapshot.values) + + # Resume the graph from the interruption point. + message = """ + Thought: Finish. 
+ Action: + ``` + { + "action": "final_answer", + "action_input": "Plan" + } + ``` + """ + mock.return_value = RunnableLambda(lambda _: messages.AIMessage(content=message)) + output = self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward") + expected_output = [ + ("message", HumanMessage(content="It's straightforward")), + ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])), + ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])), + ] + self.assertConversationEqual(output, expected_output) + snapshot: StateSnapshot = graph.get_state(config) + self.assertFalse(snapshot.next) + self.assertEqual(snapshot.values.get("intermediate_steps"), []) + self.assertEqual(snapshot.values["plan"], "Plan") + + def test_trends_interrupt_when_asking_for_help(self): + graph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) + .add_trends_planner(AssistantNodeName.END) + .compile() + ) + self._test_human_in_the_loop(graph) + + def test_funnels_interrupt_when_asking_for_help(self): + graph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER) + .add_funnel_planner(AssistantNodeName.END) + .compile() + ) + self._test_human_in_the_loop(graph) + + def test_intermediate_steps_are_updated_after_feedback(self): + with patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") as mock: + graph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) + .add_trends_planner(AssistantNodeName.END) + .compile() + ) + config: RunnableConfig = { + "configurable": { + "thread_id": self.conversation.id, + } + } + + # Interrupt the graph + message = """ + Thought: Let's ask for help. 
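The human-in-the-loop tests above exercise the new interrupt flow: when the planner emits an `ask_user_for_help` action, the graph pauses with that step left unanswered, and the user's next message is attached as the observation of the pending step once the run resumes. A minimal sketch of that bookkeeping, assuming only `langchain_core` (the real resume logic lives in the taxonomy agent nodes and may differ), consistent with what the feedback test below asserts:

```python
# Sketch only: fill the observation of the last unanswered intermediate step
# with the user's reply, which is the behaviour asserted in
# test_intermediate_steps_are_updated_after_feedback.
from typing import Optional

from langchain_core.agents import AgentAction

IntermediateStep = tuple[AgentAction, Optional[str]]


def attach_feedback(steps: list[IntermediateStep], feedback: str) -> list[IntermediateStep]:
    updated = list(steps)
    for i in range(len(updated) - 1, -1, -1):
        action, observation = updated[i]
        if observation is None:
            updated[i] = (action, feedback)
            break
    return updated


steps = [(AgentAction(tool="ask_user_for_help", tool_input="Need help with this query", log=""), None)]
steps = attach_feedback(steps, "It's straightforward")
assert steps[0][1] == "It's straightforward"
```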
+ Action: + ``` + { + "action": "ask_user_for_help", + "action_input": "Need help with this query" + } + ``` + """ + mock.return_value = RunnableLambda(lambda _: messages.AIMessage(content=message)) + self._run_assistant_graph(graph, conversation=self.conversation) + snapshot: StateSnapshot = graph.get_state(config) + self.assertTrue(snapshot.next) + self.assertIn("intermediate_steps", snapshot.values) + self.assertEqual(len(snapshot.values["intermediate_steps"]), 1) + action, observation = snapshot.values["intermediate_steps"][0] + self.assertEqual(action.tool, "ask_user_for_help") + self.assertIsNone(observation) + + self._run_assistant_graph(graph, conversation=self.conversation, message="It's straightforward") + snapshot: StateSnapshot = graph.get_state(config) + self.assertTrue(snapshot.next) + self.assertIn("intermediate_steps", snapshot.values) + self.assertEqual(len(snapshot.values["intermediate_steps"]), 2) + action, observation = snapshot.values["intermediate_steps"][0] + self.assertEqual(action.tool, "ask_user_for_help") + self.assertEqual(observation, "It's straightforward") + action, observation = snapshot.values["intermediate_steps"][1] + self.assertEqual(action.tool, "ask_user_for_help") + self.assertIsNone(observation) + + def test_new_conversation_handles_serialized_conversation(self): + graph = ( + AssistantGraph(self.team) + .add_node(AssistantNodeName.ROUTER, lambda _: {"messages": [AssistantMessage(content="Hello")]}) + .add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER) + .add_edge(AssistantNodeName.ROUTER, AssistantNodeName.END) + .compile() + ) + output = self._run_assistant_graph( + graph, + conversation=self.conversation, + is_new_conversation=True, + ) + expected_output = [ + ("conversation", {"id": str(self.conversation.id)}), + ] + self.assertConversationEqual(output[:1], expected_output) + + output = self._run_assistant_graph( + graph, + conversation=self.conversation, + is_new_conversation=False, + ) + self.assertNotEqual(output[0][0], "conversation") diff --git a/ee/hogai/test/test_utils.py b/ee/hogai/test/test_utils.py index 42e54d058c556..8c32471c88508 100644 --- a/ee/hogai/test/test_utils.py +++ b/ee/hogai/test/test_utils.py @@ -1,6 +1,4 @@ -from langchain_core.messages import HumanMessage as LangchainHumanMessage - -from ee.hogai.utils import filter_visualization_conversation, merge_human_messages +from ee.hogai.utils.helpers import filter_messages from posthog.schema import ( AssistantMessage, AssistantTrendsQuery, @@ -13,40 +11,29 @@ class TestTrendsUtils(BaseTest): - def test_merge_human_messages(self): - res = merge_human_messages( - [ - LangchainHumanMessage(content="Text"), - LangchainHumanMessage(content="Text"), - LangchainHumanMessage(content="Te"), - LangchainHumanMessage(content="xt"), - ] - ) - self.assertEqual(len(res), 1) - self.assertEqual(res, [LangchainHumanMessage(content="Text\nTe\nxt")]) - - def test_filter_trends_conversation(self): - human_messages, visualization_messages = filter_visualization_conversation( + def test_filters_and_merges_human_messages(self): + conversation = [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="plan"), + HumanMessage(content="Text2"), + VisualizationMessage(answer=None, plan="plan"), + ] + messages = filter_messages(conversation) + self.assertEqual(len(messages), 4) + self.assertEqual( [ - HumanMessage(content="Text"), - FailureMessage(content="Error"), - 
HumanMessage(content="Text"), + HumanMessage(content="Text\nText"), VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="plan"), HumanMessage(content="Text2"), VisualizationMessage(answer=None, plan="plan"), - ] - ) - self.assertEqual(len(human_messages), 2) - self.assertEqual(len(visualization_messages), 1) - self.assertEqual( - human_messages, [LangchainHumanMessage(content="Text"), LangchainHumanMessage(content="Text2")] - ) - self.assertEqual( - visualization_messages, [VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="plan")] + ], + messages, ) def test_filters_typical_conversation(self): - human_messages, visualization_messages = filter_visualization_conversation( + messages = filter_messages( [ HumanMessage(content="Question 1"), RouterMessage(content="trends"), @@ -58,15 +45,30 @@ def test_filters_typical_conversation(self): AssistantMessage(content="Summary 2"), ] ) - self.assertEqual(len(human_messages), 2) - self.assertEqual(len(visualization_messages), 2) - self.assertEqual( - human_messages, [LangchainHumanMessage(content="Question 1"), LangchainHumanMessage(content="Question 2")] - ) + self.assertEqual(len(messages), 6) self.assertEqual( - visualization_messages, + messages, [ + HumanMessage(content="Question 1"), VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"), + AssistantMessage(content="Summary 1"), + HumanMessage(content="Question 2"), VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"), + AssistantMessage(content="Summary 2"), + ], + ) + + def test_joins_human_messages(self): + messages = filter_messages( + [ + HumanMessage(content="Question 1"), + HumanMessage(content="Question 2"), + ] + ) + self.assertEqual(len(messages), 1) + self.assertEqual( + messages, + [ + HumanMessage(content="Question 1\nQuestion 2"), ], ) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index b6b33cf6d8354..e430b4036e043 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -6,12 +6,12 @@ from ee.hogai.taxonomy_agent.nodes import TaxonomyAgentPlannerNode, TaxonomyAgentPlannerToolsNode from ee.hogai.trends.prompts import REACT_SYSTEM_PROMPT, TRENDS_SYSTEM_PROMPT from ee.hogai.trends.toolkit import TRENDS_SCHEMA, TrendsTaxonomyAgentToolkit -from ee.hogai.utils import AssistantState +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.schema import AssistantTrendsQuery class TrendsPlannerNode(TaxonomyAgentPlannerNode): - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: toolkit = TrendsTaxonomyAgentToolkit(self._team) prompt = ChatPromptTemplate.from_messages( [ @@ -23,7 +23,7 @@ def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: class TrendsPlannerToolsNode(TaxonomyAgentPlannerToolsNode): - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: toolkit = TrendsTaxonomyAgentToolkit(self._team) return super()._run_with_toolkit(state, toolkit, config=config) @@ -36,7 +36,7 @@ class TrendsGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]): OUTPUT_MODEL = TrendsSchemaGeneratorOutput OUTPUT_SCHEMA = TRENDS_SCHEMA - def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState: + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: prompt = 
ChatPromptTemplate.from_messages( [ ("system", TRENDS_SYSTEM_PROMPT), diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index 2ac9496480cdd..dcc1daeaa5a00 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -12,6 +12,8 @@ {{react_format}} +{{react_human_in_the_loop}} + Below you will find information on how to correctly discover the taxonomy of the user's data. diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 44973b3195377..369ce8bc9b292 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -4,6 +4,7 @@ from langchain_core.runnables import RunnableLambda from ee.hogai.trends.nodes import TrendsGeneratorNode, TrendsSchemaGeneratorOutput +from ee.hogai.utils.types import AssistantState, PartialAssistantState from posthog.schema import ( AssistantTrendsQuery, HumanMessage, @@ -17,6 +18,7 @@ class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest): maxDiff = None def setUp(self): + super().setUp() self.schema = AssistantTrendsQuery(series=[]) def test_node_runs(self): @@ -26,16 +28,16 @@ def test_node_runs(self): lambda _: TrendsSchemaGeneratorOutput(query=self.schema).model_dump() ) new_state = node.run( - { - "messages": [HumanMessage(content="Text")], - "plan": "Plan", - }, + AssistantState( + messages=[HumanMessage(content="Text")], + plan="Plan", + ), {}, ) self.assertEqual( new_state, - { - "messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)], - "intermediate_steps": None, - }, + PartialAssistantState( + messages=[VisualizationMessage(answer=self.schema, plan="Plan", id=new_state.messages[0].id)], + intermediate_steps=None, + ), ) diff --git a/ee/hogai/trends/toolkit.py b/ee/hogai/trends/toolkit.py index d69830d2f2cd6..5fd7a35f0f18a 100644 --- a/ee/hogai/trends/toolkit.py +++ b/ee/hogai/trends/toolkit.py @@ -1,8 +1,6 @@ from ee.hogai.taxonomy_agent.toolkit import TaxonomyAgentToolkit, ToolkitTool -from ee.hogai.utils import dereference_schema -from posthog.schema import ( - AssistantTrendsQuery, -) +from ee.hogai.utils.helpers import dereference_schema +from posthog.schema import AssistantTrendsQuery class TrendsTaxonomyAgentToolkit(TaxonomyAgentToolkit): diff --git a/ee/hogai/utils.py b/ee/hogai/utils.py deleted file mode 100644 index 559a369df83c8..0000000000000 --- a/ee/hogai/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -import operator -from abc import ABC, abstractmethod -from collections.abc import Sequence -from enum import StrEnum -from typing import Annotated, Optional, TypedDict, Union - -from jsonref import replace_refs -from langchain_core.agents import AgentAction -from langchain_core.messages import ( - HumanMessage as LangchainHumanMessage, - merge_message_runs, -) -from langchain_core.runnables import RunnableConfig -from langgraph.graph import END, START -from pydantic import BaseModel, Field - -from posthog.models.team.team import Team -from posthog.schema import ( - AssistantMessage, - FailureMessage, - HumanMessage, - ReasoningMessage, - RootAssistantMessage, - RouterMessage, - VisualizationMessage, -) - -AssistantMessageUnion = Union[ - AssistantMessage, HumanMessage, VisualizationMessage, FailureMessage, RouterMessage, ReasoningMessage -] - - -class Conversation(BaseModel): - messages: list[RootAssistantMessage] = Field(..., min_length=1, max_length=50) - session_id: str - - -class AssistantState(TypedDict, total=False): - messages: Annotated[Sequence[AssistantMessageUnion], operator.add] - intermediate_steps: 
Optional[list[tuple[AgentAction, Optional[str]]]] - plan: Optional[str] - - -class AssistantNodeName(StrEnum): - START = START - END = END - ROUTER = "router" - TRENDS_PLANNER = "trends_planner" - TRENDS_PLANNER_TOOLS = "trends_planner_tools" - TRENDS_GENERATOR = "trends_generator" - TRENDS_GENERATOR_TOOLS = "trends_generator_tools" - FUNNEL_PLANNER = "funnel_planner" - FUNNEL_PLANNER_TOOLS = "funnel_planner_tools" - FUNNEL_GENERATOR = "funnel_generator" - FUNNEL_GENERATOR_TOOLS = "funnel_generator_tools" - SUMMARIZER = "summarizer" - - -class AssistantNode(ABC): - _team: Team - - def __init__(self, team: Team): - self._team = team - - @abstractmethod - def run(cls, state: AssistantState, config: RunnableConfig) -> AssistantState: - raise NotImplementedError - - -def remove_line_breaks(line: str) -> str: - return line.replace("\n", " ") - - -def merge_human_messages(messages: list[LangchainHumanMessage]) -> list[LangchainHumanMessage]: - """ - Filters out duplicated human messages and merges them into one message. - """ - contents = set() - filtered_messages = [] - for message in messages: - if message.content in contents: - continue - contents.add(message.content) - filtered_messages.append(message) - return merge_message_runs(filtered_messages) - - -def filter_visualization_conversation( - messages: Sequence[AssistantMessageUnion], -) -> tuple[list[LangchainHumanMessage], list[VisualizationMessage]]: - """ - Splits, filters and merges the message history to be consumable by agents. Returns human and visualization messages. - """ - stack: list[LangchainHumanMessage] = [] - human_messages: list[LangchainHumanMessage] = [] - visualization_messages: list[VisualizationMessage] = [] - - for message in messages: - if isinstance(message, HumanMessage): - stack.append(LangchainHumanMessage(content=message.content)) - elif isinstance(message, VisualizationMessage) and message.answer: - if stack: - human_messages += merge_human_messages(stack) - stack = [] - visualization_messages.append(message) - - if stack: - human_messages += merge_human_messages(stack) - - return human_messages, visualization_messages - - -def dereference_schema(schema: dict) -> dict: - new_schema: dict = replace_refs(schema, proxies=False, lazy_load=False) - if "$defs" in new_schema: - new_schema.pop("$defs") - return new_schema diff --git a/ee/hogai/utils/__init__.py b/ee/hogai/utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ee/hogai/utils/helpers.py b/ee/hogai/utils/helpers.py new file mode 100644 index 0000000000000..4fc8cf3b5d6a0 --- /dev/null +++ b/ee/hogai/utils/helpers.py @@ -0,0 +1,79 @@ +from collections.abc import Sequence +from typing import Optional, TypeVar, Union + +from jsonref import replace_refs +from langchain_core.messages import ( + HumanMessage as LangchainHumanMessage, + merge_message_runs, +) + +from posthog.schema import ( + AssistantMessage, + HumanMessage, + VisualizationMessage, +) + +from .types import AIMessageUnion, AssistantMessageUnion + + +def remove_line_breaks(line: str) -> str: + return line.replace("\n", " ") + + +def filter_messages( + messages: Sequence[AssistantMessageUnion], + entity_filter: Union[tuple[type[AIMessageUnion], ...], type[AIMessageUnion]] = ( + AssistantMessage, + VisualizationMessage, + ), +) -> list[AssistantMessageUnion]: + """ + Filters and merges the message history to be consumable by agents. Returns human and AI messages. 
+ """ + stack: list[LangchainHumanMessage] = [] + filtered_messages: list[AssistantMessageUnion] = [] + + def _merge_stack(stack: list[LangchainHumanMessage]) -> list[HumanMessage]: + return [ + HumanMessage(content=langchain_message.content, id=langchain_message.id) + for langchain_message in merge_message_runs(stack) + ] + + for message in messages: + if isinstance(message, HumanMessage): + stack.append(LangchainHumanMessage(content=message.content, id=message.id)) + elif isinstance(message, entity_filter): + if stack: + filtered_messages += _merge_stack(stack) + stack = [] + filtered_messages.append(message) + + if stack: + filtered_messages += _merge_stack(stack) + + return filtered_messages + + +T = TypeVar("T", bound=AssistantMessageUnion) + + +def find_last_message_of_type(messages: Sequence[AssistantMessageUnion], message_type: type[T]) -> Optional[T]: + return next((msg for msg in reversed(messages) if isinstance(msg, message_type)), None) + + +def slice_messages_to_conversation_start( + messages: Sequence[AssistantMessageUnion], start_id: Optional[str] = None +) -> Sequence[AssistantMessageUnion]: + result = [] + for msg in messages: + result.append(msg) + if msg.id == start_id: + break + return result + + +def dereference_schema(schema: dict) -> dict: + new_schema: dict = replace_refs(schema, proxies=False, lazy_load=False) + if "$defs" in new_schema: + new_schema.pop("$defs") + return new_schema diff --git a/ee/hogai/utils/nodes.py b/ee/hogai/utils/nodes.py new file mode 100644 index 0000000000000..6a4358243b666 --- /dev/null +++ b/ee/hogai/utils/nodes.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + +from langchain_core.runnables import RunnableConfig + +from posthog.models.team.team import Team + +from .types import AssistantState, PartialAssistantState + + +class AssistantNode(ABC): + _team: Team + + def __init__(self, team: Team): + self._team = team + + @abstractmethod + def run(cls, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + raise NotImplementedError diff --git a/ee/hogai/utils/state.py b/ee/hogai/utils/state.py new file mode 100644 index 0000000000000..3392f3362adb9 --- /dev/null +++ b/ee/hogai/utils/state.py @@ -0,0 +1,70 @@ +from typing import Any, Literal, TypedDict, TypeGuard, Union + +from langchain_core.messages import AIMessageChunk + +from ee.hogai.utils.types import AssistantNodeName, AssistantState, PartialAssistantState + +# A state update can have a partial state or a LangGraph's reserved dataclasses like Interrupt. +GraphValueUpdate = dict[AssistantNodeName, dict[Any, Any] | Any] + +GraphValueUpdateTuple = tuple[Literal["values"], GraphValueUpdate] + + +def is_value_update(update: list[Any]) -> TypeGuard[GraphValueUpdateTuple]: + """ + Transition between nodes. + + Returns: + PartialAssistantState, Interrupt, or other LangGraph reserved dataclasses. 
+ """ + return len(update) == 2 and update[0] == "updates" + + +def validate_value_update(update: GraphValueUpdate) -> dict[AssistantNodeName, PartialAssistantState | Any]: + validated_update = {} + for node_name, value in update.items(): + if isinstance(value, dict): + validated_update[node_name] = PartialAssistantState.model_validate(value) + else: + validated_update[node_name] = value + return validated_update + + +class LangGraphState(TypedDict): + langgraph_node: AssistantNodeName + + +GraphMessageUpdateTuple = tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]] + + +def is_message_update(update: list[Any]) -> TypeGuard[GraphMessageUpdateTuple]: + """ + Streaming of messages. + """ + return len(update) == 2 and update[0] == "messages" + + +GraphStateUpdateTuple = tuple[Literal["updates"], dict[Any, Any]] + + +def is_state_update(update: list[Any]) -> TypeGuard[GraphStateUpdateTuple]: + """ + Update of the state. Returns a full state. + """ + return len(update) == 2 and update[0] == "values" + + +def validate_state_update(state_update: dict[Any, Any]) -> AssistantState: + return AssistantState.model_validate(state_update) + + +GraphTaskStartedUpdateTuple = tuple[Literal["debug"], tuple[Union[AIMessageChunk, Any], LangGraphState]] + + +def is_task_started_update( + update: list[Any], +) -> TypeGuard[GraphTaskStartedUpdateTuple]: + """ + Streaming of messages. + """ + return len(update) == 2 and update[0] == "debug" and update[1]["type"] == "task" diff --git a/ee/hogai/utils/types.py b/ee/hogai/utils/types.py new file mode 100644 index 0000000000000..2df027b6f85af --- /dev/null +++ b/ee/hogai/utils/types.py @@ -0,0 +1,52 @@ +import operator +from collections.abc import Sequence +from enum import StrEnum +from typing import Annotated, Optional, Union + +from langchain_core.agents import AgentAction +from langgraph.graph import END, START +from pydantic import BaseModel, Field + +from posthog.schema import ( + AssistantMessage, + FailureMessage, + HumanMessage, + ReasoningMessage, + RouterMessage, + VisualizationMessage, +) + +AIMessageUnion = Union[AssistantMessage, VisualizationMessage, FailureMessage, RouterMessage, ReasoningMessage] +AssistantMessageUnion = Union[HumanMessage, AIMessageUnion] + + +class _SharedAssistantState(BaseModel): + intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]] = Field(default=None) + start_id: Optional[str] = Field(default=None) + """ + The ID of the message from which the conversation started. 
+ """ + plan: Optional[str] = Field(default=None) + + +class AssistantState(_SharedAssistantState): + messages: Annotated[Sequence[AssistantMessageUnion], operator.add] + + +class PartialAssistantState(_SharedAssistantState): + messages: Optional[Annotated[Sequence[AssistantMessageUnion], operator.add]] = Field(default=None) + + +class AssistantNodeName(StrEnum): + START = START + END = END + ROUTER = "router" + TRENDS_PLANNER = "trends_planner" + TRENDS_PLANNER_TOOLS = "trends_planner_tools" + TRENDS_GENERATOR = "trends_generator" + TRENDS_GENERATOR_TOOLS = "trends_generator_tools" + FUNNEL_PLANNER = "funnel_planner" + FUNNEL_PLANNER_TOOLS = "funnel_planner_tools" + FUNNEL_GENERATOR = "funnel_generator" + FUNNEL_GENERATOR_TOOLS = "funnel_generator_tools" + SUMMARIZER = "summarizer" diff --git a/ee/migrations/0018_conversation_conversationcheckpoint_and_more.py b/ee/migrations/0018_conversation_conversationcheckpoint_and_more.py new file mode 100644 index 0000000000000..ec48cc780ad57 --- /dev/null +++ b/ee/migrations/0018_conversation_conversationcheckpoint_and_more.py @@ -0,0 +1,147 @@ +# Generated by Django 4.2.15 on 2024-12-11 15:51 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import posthog.models.utils + + +class Migration(migrations.Migration): + dependencies = [ + ("posthog", "0528_project_field_in_taxonomy"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("ee", "0017_accesscontrol_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="Conversation", + fields=[ + ( + "id", + models.UUIDField( + default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False + ), + ), + ("team", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="posthog.team")), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="ConversationCheckpoint", + fields=[ + ( + "id", + models.UUIDField( + default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False + ), + ), + ( + "checkpoint_ns", + models.TextField( + default="", + help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).', + ), + ), + ("checkpoint", models.JSONField(help_text="Serialized checkpoint data.", null=True)), + ("metadata", models.JSONField(help_text="Serialized checkpoint metadata.", null=True)), + ( + "parent_checkpoint", + models.ForeignKey( + help_text="Parent checkpoint ID.", + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="children", + to="ee.conversationcheckpoint", + ), + ), + ( + "thread", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, related_name="checkpoints", to="ee.conversation" + ), + ), + ], + ), + migrations.CreateModel( + name="ConversationCheckpointWrite", + fields=[ + ( + "id", + models.UUIDField( + default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False + ), + ), + ("task_id", models.UUIDField(help_text="Identifier for the task creating the checkpoint write.")), + ( + "idx", + models.IntegerField( + help_text="Index of the checkpoint write. It is an integer value where negative numbers are reserved for special cases, such as node interruption." 
+ ), + ), + ( + "channel", + models.TextField( + help_text="An arbitrary string defining the channel name. For example, it can be a node name or a reserved LangGraph's enum." + ), + ), + ("type", models.TextField(help_text="Type of the serialized blob. For example, `json`.", null=True)), + ("blob", models.BinaryField(null=True)), + ( + "checkpoint", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="writes", + to="ee.conversationcheckpoint", + ), + ), + ], + ), + migrations.CreateModel( + name="ConversationCheckpointBlob", + fields=[ + ( + "id", + models.UUIDField( + default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False + ), + ), + ( + "channel", + models.TextField( + help_text="An arbitrary string defining the channel name. For example, it can be a node name or a reserved LangGraph's enum." + ), + ), + ("version", models.TextField(help_text="Monotonically increasing version of the channel.")), + ("type", models.TextField(help_text="Type of the serialized blob. For example, `json`.", null=True)), + ("blob", models.BinaryField(null=True)), + ( + "checkpoint", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="blobs", + to="ee.conversationcheckpoint", + ), + ), + ], + ), + migrations.AddConstraint( + model_name="conversationcheckpointwrite", + constraint=models.UniqueConstraint( + fields=("checkpoint_id", "task_id", "idx"), name="unique_checkpoint_write" + ), + ), + migrations.AddConstraint( + model_name="conversationcheckpointblob", + constraint=models.UniqueConstraint( + fields=("checkpoint_id", "channel", "version"), name="unique_checkpoint_blob" + ), + ), + migrations.AddConstraint( + model_name="conversationcheckpoint", + constraint=models.UniqueConstraint(fields=("id", "checkpoint_ns", "thread"), name="unique_checkpoint"), + ), + ] diff --git a/ee/migrations/max_migration.txt b/ee/migrations/max_migration.txt index 449d87290c304..fb889f1cc34cf 100644 --- a/ee/migrations/max_migration.txt +++ b/ee/migrations/max_migration.txt @@ -1 +1 @@ -0017_accesscontrol_and_more +0018_conversation_conversationcheckpoint_and_more diff --git a/ee/models/__init__.py b/ee/models/__init__.py index df7cfcba704e6..2067d11f7618f 100644 --- a/ee/models/__init__.py +++ b/ee/models/__init__.py @@ -1,3 +1,4 @@ +from .assistant import Conversation, ConversationCheckpoint, ConversationCheckpointBlob, ConversationCheckpointWrite from .dashboard_privilege import DashboardPrivilege from .event_definition import EnterpriseEventDefinition from .explicit_team_membership import ExplicitTeamMembership @@ -10,7 +11,11 @@ __all__ = [ "AccessControl", + "ConversationCheckpoint", + "ConversationCheckpointBlob", + "ConversationCheckpointWrite", "DashboardPrivilege", + "Conversation", "EnterpriseEventDefinition", "EnterprisePropertyDefinition", "ExplicitTeamMembership", diff --git a/ee/models/assistant.py b/ee/models/assistant.py new file mode 100644 index 0000000000000..390a7ab7a117f --- /dev/null +++ b/ee/models/assistant.py @@ -0,0 +1,83 @@ +from collections.abc import Iterable + +from django.db import models +from langgraph.checkpoint.serde.types import TASKS + +from posthog.models.team.team import Team +from posthog.models.user import User +from posthog.models.utils import UUIDModel + + +class Conversation(UUIDModel): + user = models.ForeignKey(User, on_delete=models.CASCADE) + team = models.ForeignKey(Team, on_delete=models.CASCADE) + + +class ConversationCheckpoint(UUIDModel): + thread = models.ForeignKey(Conversation, 
on_delete=models.CASCADE, related_name="checkpoints") + checkpoint_ns = models.TextField( + default="", + help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).', + ) + parent_checkpoint = models.ForeignKey( + "self", null=True, on_delete=models.CASCADE, related_name="children", help_text="Parent checkpoint ID." + ) + checkpoint = models.JSONField(null=True, help_text="Serialized checkpoint data.") + metadata = models.JSONField(null=True, help_text="Serialized checkpoint metadata.") + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["id", "checkpoint_ns", "thread"], + name="unique_checkpoint", + ) + ] + + @property + def pending_sends(self) -> Iterable["ConversationCheckpointWrite"]: + if self.parent_checkpoint is None: + return [] + return self.parent_checkpoint.writes.filter(channel=TASKS).order_by("task_id", "idx") + + @property + def pending_writes(self) -> Iterable["ConversationCheckpointWrite"]: + return self.writes.order_by("idx", "task_id") + + +class ConversationCheckpointBlob(UUIDModel): + checkpoint = models.ForeignKey(ConversationCheckpoint, on_delete=models.CASCADE, related_name="blobs") + channel = models.TextField( + help_text="An arbitrary string defining the channel name. For example, it can be a node name or a reserved LangGraph's enum." + ) + version = models.TextField(help_text="Monotonically increasing version of the channel.") + type = models.TextField(null=True, help_text="Type of the serialized blob. For example, `json`.") + blob = models.BinaryField(null=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["checkpoint_id", "channel", "version"], + name="unique_checkpoint_blob", + ) + ] + + +class ConversationCheckpointWrite(UUIDModel): + checkpoint = models.ForeignKey(ConversationCheckpoint, on_delete=models.CASCADE, related_name="writes") + task_id = models.UUIDField(help_text="Identifier for the task creating the checkpoint write.") + idx = models.IntegerField( + help_text="Index of the checkpoint write. It is an integer value where negative numbers are reserved for special cases, such as node interruption." + ) + channel = models.TextField( + help_text="An arbitrary string defining the channel name. For example, it can be a node name or a reserved LangGraph's enum." + ) + type = models.TextField(null=True, help_text="Type of the serialized blob. 
For example, `json`.") + blob = models.BinaryField(null=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["checkpoint_id", "task_id", "idx"], + name="unique_checkpoint_write", + ) + ] diff --git a/ee/urls.py b/ee/urls.py index 7c722bc31852f..91b58e0fcb238 100644 --- a/ee/urls.py +++ b/ee/urls.py @@ -6,11 +6,11 @@ from django.urls.conf import path from ee.api import integration -from .api.rbac import organization_resource_access, role from .api import ( authentication, billing, + conversation, dashboard_collaborator, explicit_team_member, feature_flag_role_access, @@ -19,18 +19,20 @@ sentry_stats, subscription, ) +from .api.rbac import organization_resource_access, role from .session_recordings import session_recording_playlist def extend_api_router() -> None: from posthog.api import ( - router as root_router, - register_grandfathered_environment_nested_viewset, - projects_router, - organizations_router, - project_feature_flags_router, environment_dashboards_router, + environments_router, legacy_project_dashboards_router, + organizations_router, + project_feature_flags_router, + projects_router, + register_grandfathered_environment_nested_viewset, + router as root_router, ) root_router.register(r"billing", billing.BillingViewset, "billing") @@ -93,6 +95,10 @@ def extend_api_router() -> None: ["project_id"], ) + environments_router.register( + r"conversations", conversation.ConversationViewSet, "environment_conversations", ["team_id"] + ) + # The admin interface is disabled on self-hosted instances, as its misuse can be unsafe admin_urlpatterns = ( diff --git a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png index ab8433877a9e5..64cc334785eda 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png index e5bac9eeee8f1..756eb74f6156b 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png and b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png index db43716077ebe..27244118ba4a0 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png index cc409784b1c45..6f13b687b5156 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png and b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png b/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png index 6f6b491670b99..78d4b44700226 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread--light.png 
b/frontend/__snapshots__/scenes-app-max-ai--thread--light.png index c456bd0c5fde0..d300a75be4dc6 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread--light.png and b/frontend/__snapshots__/scenes-app-max-ai--thread--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--dark.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--dark.png index 803390e1b9822..f7e1627246573 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--light.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--light.png index b785e83e2206f..1bcb94198a8b6 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--light.png and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-rate-limit--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--dark.png b/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--dark.png index be67ab75d125b..eee71a3783a06 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--light.png b/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--light.png index cf026b65c4994..358c01f1ccdae 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--light.png and b/frontend/__snapshots__/scenes-app-max-ai--welcome-loading-suggestions--light.png differ diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 37d394a7fa483..f1497f937c334 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -845,9 +845,9 @@ class ApiRequest { return apiRequest } - // Chat - public chat(teamId?: TeamType['id']): ApiRequest { - return this.environmentsDetail(teamId).addPathComponent('query').addPathComponent('chat') + // Conversations + public conversations(teamId?: TeamType['id']): ApiRequest { + return this.environmentsDetail(teamId).addPathComponent('conversations') } // Notebooks @@ -2547,12 +2547,10 @@ const api = { }) }, - chatURL: (): string => { - return new ApiRequest().chat().assembleFullUrl() - }, - - async chat(data: any): Promise { - return await api.createResponse(this.chatURL(), data) + conversations: { + async create(data: { content: string; conversation?: string | null }): Promise { + return api.createResponse(new ApiRequest().conversations().assembleFullUrl(), data) + }, }, /** Fetch data from specified URL. The result already is JSON-parsed. */ diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 95267b3f2434a..c0d0c95abccc8 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -585,7 +585,7 @@ "type": "string" }, "AssistantEventType": { - "enum": ["status", "message"], + "enum": ["status", "message", "conversation"], "type": "string" }, "AssistantFunnelsBreakdownFilter": { @@ -1092,9 +1092,8 @@ "content": { "type": "string" }, - "done": { - "description": "We only need this \"done\" value to tell when the particular message is finished during its streaming. 
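With the `conversations` viewset registered on the environments router above and the frontend client now posting `{content, conversation?}` to it, the endpoint can be smoke-tested directly. A hypothetical sketch (host and team id are placeholders, and an authenticated PostHog session is assumed; the response is a `text/event-stream` of SSE frames):

```python
# Hypothetical smoke test of the new streaming endpoint; authentication is
# omitted for brevity and the host/team id are placeholders.
import requests

TEAM_ID = 1  # placeholder
URL = f"https://posthog.example.com/api/environments/{TEAM_ID}/conversations/"

with requests.post(
    URL,
    json={"content": "What are my top events this week?"},  # add "conversation": "<uuid>" to continue a thread
    stream=True,
) as response:
    for line in response.iter_lines(decode_unicode=True):
        if line:  # SSE frames arrive as "event: ..." and "data: ..." lines
            print(line)
```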
It won't be necessary when we optimize streaming to NOT send the entire message every time a character is added.", - "type": "boolean" + "id": { + "type": "string" }, "type": { "const": "ai", @@ -1469,6 +1468,15 @@ ], "type": "string" }, + "BaseAssistantMessage": { + "additionalProperties": false, + "properties": { + "id": { + "type": "string" + } + }, + "type": "object" + }, "BaseMathType": { "enum": [ "total", @@ -6220,16 +6228,15 @@ "content": { "type": "string" }, - "done": { - "const": true, - "type": "boolean" + "id": { + "type": "string" }, "type": { "const": "ai/failure", "type": "string" } }, - "required": ["type", "done"], + "required": ["type"], "type": "object" }, "FeaturePropertyFilter": { @@ -7559,17 +7566,15 @@ "content": { "type": "string" }, - "done": { - "const": true, - "description": "Human messages are only appended when done.", - "type": "boolean" + "id": { + "type": "string" }, "type": { "const": "human", "type": "string" } }, - "required": ["type", "content", "done"], + "required": ["type", "content"], "type": "object" }, "InsightActorsQuery": { @@ -11175,9 +11180,8 @@ "content": { "type": "string" }, - "done": { - "const": true, - "type": "boolean" + "id": { + "type": "string" }, "substeps": { "items": { @@ -11190,7 +11194,7 @@ "type": "string" } }, - "required": ["type", "content", "done"], + "required": ["type", "content"], "type": "object" }, "RecordingOrder": { @@ -11637,17 +11641,15 @@ "content": { "type": "string" }, - "done": { - "const": true, - "description": "Router messages are not streamed, so they can only be done.", - "type": "boolean" + "id": { + "type": "string" }, "type": { "const": "ai/router", "type": "string" } }, - "required": ["type", "content", "done"], + "required": ["type", "content"], "type": "object" }, "SamplingRate": { @@ -12827,8 +12829,11 @@ } ] }, - "done": { - "type": "boolean" + "id": { + "type": "string" + }, + "initiator": { + "type": "string" }, "plan": { "type": "string" diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index 10aa6ac455540..89c7b9786b9b5 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -2490,48 +2490,41 @@ export enum AssistantMessageType { Router = 'ai/router', } -export interface HumanMessage { +export interface BaseAssistantMessage { + id?: string +} + +export interface HumanMessage extends BaseAssistantMessage { type: AssistantMessageType.Human content: string - /** Human messages are only appended when done. */ - done: true } -export interface AssistantMessage { +export interface AssistantMessage extends BaseAssistantMessage { type: AssistantMessageType.Assistant content: string - /** - * We only need this "done" value to tell when the particular message is finished during its streaming. - * It won't be necessary when we optimize streaming to NOT send the entire message every time a character is added. 
- */ - done?: boolean } -export interface ReasoningMessage { +export interface ReasoningMessage extends BaseAssistantMessage { type: AssistantMessageType.Reasoning content: string substeps?: string[] - done: true } -export interface VisualizationMessage { +export interface VisualizationMessage extends BaseAssistantMessage { type: AssistantMessageType.Visualization plan?: string answer?: AssistantTrendsQuery | AssistantFunnelsQuery - done?: boolean + initiator?: string } -export interface FailureMessage { +export interface FailureMessage extends BaseAssistantMessage { type: AssistantMessageType.Failure content?: string - done: true } -export interface RouterMessage { +export interface RouterMessage extends BaseAssistantMessage { type: AssistantMessageType.Router content: string - /** Router messages are not streamed, so they can only be done. */ - done: true } export type RootAssistantMessage = @@ -2545,6 +2538,7 @@ export type RootAssistantMessage = export enum AssistantEventType { Status = 'status', Message = 'message', + Conversation = 'conversation', } export enum AssistantGenerationStatusType { diff --git a/frontend/src/scenes/max/Intro.tsx b/frontend/src/scenes/max/Intro.tsx index c43cd86b53d2a..97f4f9fbfdc56 100644 --- a/frontend/src/scenes/max/Intro.tsx +++ b/frontend/src/scenes/max/Intro.tsx @@ -3,6 +3,7 @@ import { LemonButton, Popover } from '@posthog/lemon-ui' import { useActions, useValues } from 'kea' import { HedgehogBuddy } from 'lib/components/HedgehogBuddy/HedgehogBuddy' import { hedgehogBuddyLogic } from 'lib/components/HedgehogBuddy/hedgehogBuddyLogic' +import { uuid } from 'lib/utils' import { useMemo, useState } from 'react' import { maxGlobalLogic } from './maxGlobalLogic' @@ -19,13 +20,13 @@ export function Intro(): JSX.Element { const { hedgehogConfig } = useValues(hedgehogBuddyLogic) const { acceptDataProcessing } = useActions(maxGlobalLogic) const { dataProcessingAccepted } = useValues(maxGlobalLogic) - const { sessionId } = useValues(maxLogic) + const { conversation } = useValues(maxLogic) const [hedgehogDirection, setHedgehogDirection] = useState<'left' | 'right'>('right') const headline = useMemo(() => { - return HEADLINES[parseInt(sessionId.split('-').at(-1) as string, 16) % HEADLINES.length] - }, []) + return HEADLINES[parseInt((conversation?.id || uuid()).split('-').at(-1) as string, 16) % HEADLINES.length] + }, [conversation?.id]) return ( <> diff --git a/frontend/src/scenes/max/Max.stories.tsx b/frontend/src/scenes/max/Max.stories.tsx index bec5a519de8e0..51dc03ab0cb5c 100644 --- a/frontend/src/scenes/max/Max.stories.tsx +++ b/frontend/src/scenes/max/Max.stories.tsx @@ -6,7 +6,13 @@ import { projectLogic } from 'scenes/projectLogic' import { mswDecorator, useStorybookMocks } from '~/mocks/browser' -import { chatResponseChunk, failureChunk, generationFailureChunk } from './__mocks__/chatResponse.mocks' +import { + chatResponseChunk, + CONVERSATION_ID, + failureChunk, + generationFailureChunk, + humanMessage, +} from './__mocks__/chatResponse.mocks' import { MaxInstance } from './Max' import { maxGlobalLogic } from './maxGlobalLogic' import { maxLogic } from './maxLogic' @@ -16,7 +22,7 @@ const meta: Meta = { decorators: [ mswDecorator({ post: { - '/api/environments/:team_id/query/chat/': (_, res, ctx) => res(ctx.text(chatResponseChunk)), + '/api/environments/:team_id/conversations/': (_, res, ctx) => res(ctx.text(chatResponseChunk)), }, }), ], @@ -28,10 +34,7 @@ const meta: Meta = { } export default meta -// The session ID is hard-coded here, as it's used 
for randomizing the welcome headline -const SESSION_ID = 'b1b4b3b4-1b3b-4b3b-1b3b4b3b4b3b' - -const Template = ({ sessionId: SESSION_ID }: { sessionId: string }): JSX.Element => { +const Template = ({ conversationId: CONVERSATION_ID }: { conversationId: string }): JSX.Element => { const { acceptDataProcessing } = useActions(maxGlobalLogic) useEffect(() => { @@ -40,7 +43,7 @@ const Template = ({ sessionId: SESSION_ID }: { sessionId: string }): JSX.Element return (
- +
@@ -69,7 +72,7 @@ export const Welcome: StoryFn = () => { acceptDataProcessing(false) }, []) - return