-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into feat/slow-rollout-decide
- Loading branch information
Showing
154 changed files
with
3,373 additions
and
1,729 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from typing import cast | ||
|
||
from django.http import StreamingHttpResponse | ||
from pydantic import ValidationError | ||
from rest_framework import serializers | ||
from rest_framework.renderers import BaseRenderer | ||
from rest_framework.request import Request | ||
from rest_framework.viewsets import GenericViewSet | ||
|
||
from ee.hogai.assistant import Assistant | ||
from ee.models.assistant import Conversation | ||
from posthog.api.routing import TeamAndOrgViewSetMixin | ||
from posthog.models.user import User | ||
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle | ||
from posthog.schema import HumanMessage | ||
|
||
|
||
class MessageSerializer(serializers.Serializer): | ||
content = serializers.CharField(required=True, max_length=1000) | ||
conversation = serializers.UUIDField(required=False) | ||
|
||
def validate(self, data): | ||
try: | ||
message = HumanMessage(content=data["content"]) | ||
data["message"] = message | ||
except ValidationError: | ||
raise serializers.ValidationError("Invalid message content.") | ||
return data | ||
|
||
|
||
class ServerSentEventRenderer(BaseRenderer): | ||
media_type = "text/event-stream" | ||
format = "txt" | ||
|
||
def render(self, data, accepted_media_type=None, renderer_context=None): | ||
return data | ||
|
||
|
||
class ConversationViewSet(TeamAndOrgViewSetMixin, GenericViewSet): | ||
scope_object = "INTERNAL" | ||
serializer_class = MessageSerializer | ||
renderer_classes = [ServerSentEventRenderer] | ||
queryset = Conversation.objects.all() | ||
lookup_url_kwarg = "conversation" | ||
|
||
def safely_get_queryset(self, queryset): | ||
# Only allow access to conversations created by the current user | ||
return queryset.filter(user=self.request.user) | ||
|
||
def get_throttles(self): | ||
return [AIBurstRateThrottle(), AISustainedRateThrottle()] | ||
|
||
def create(self, request: Request, *args, **kwargs): | ||
serializer = self.get_serializer(data=request.data) | ||
serializer.is_valid(raise_exception=True) | ||
conversation_id = serializer.validated_data.get("conversation") | ||
if conversation_id: | ||
self.kwargs[self.lookup_url_kwarg] = conversation_id | ||
conversation = self.get_object() | ||
else: | ||
conversation = self.get_queryset().create(user=request.user, team=self.team) | ||
assistant = Assistant( | ||
self.team, | ||
conversation, | ||
serializer.validated_data["message"], | ||
user=cast(User, request.user), | ||
is_new_conversation=not conversation_id, | ||
) | ||
return StreamingHttpResponse(assistant.stream(), content_type=ServerSentEventRenderer.media_type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
from unittest.mock import patch | ||
|
||
from rest_framework import status | ||
|
||
from ee.hogai.assistant import Assistant | ||
from ee.models.assistant import Conversation | ||
from posthog.models.team.team import Team | ||
from posthog.models.user import User | ||
from posthog.test.base import APIBaseTest | ||
|
||
|
||
class TestConversation(APIBaseTest): | ||
def setUp(self): | ||
super().setUp() | ||
self.other_team = Team.objects.create(organization=self.organization, name="other team") | ||
self.other_user = User.objects.create_and_join( | ||
organization=self.organization, | ||
email="[email protected]", | ||
password="password", | ||
first_name="Other", | ||
) | ||
|
||
def _get_streaming_content(self, response): | ||
return b"".join(response.streaming_content) | ||
|
||
def test_create_conversation(self): | ||
with patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock: | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"content": "test query"}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_200_OK) | ||
self.assertEqual(self._get_streaming_content(response), b"test response") | ||
self.assertEqual(Conversation.objects.count(), 1) | ||
conversation: Conversation = Conversation.objects.first() | ||
self.assertEqual(conversation.user, self.user) | ||
self.assertEqual(conversation.team, self.team) | ||
stream_mock.assert_called_once() | ||
|
||
def test_add_message_to_existing_conversation(self): | ||
with patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock: | ||
conversation = Conversation.objects.create(user=self.user, team=self.team) | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{ | ||
"conversation": str(conversation.id), | ||
"content": "test query", | ||
}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_200_OK) | ||
self.assertEqual(self._get_streaming_content(response), b"test response") | ||
self.assertEqual(Conversation.objects.count(), 1) | ||
stream_mock.assert_called_once() | ||
|
||
def test_cant_access_other_users_conversation(self): | ||
conversation = Conversation.objects.create(user=self.other_user, team=self.team) | ||
|
||
self.client.force_login(self.user) | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"conversation": conversation.id, "content": "test query"}, | ||
) | ||
|
||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) | ||
|
||
def test_cant_access_other_teams_conversation(self): | ||
conversation = Conversation.objects.create(user=self.user, team=self.other_team) | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"conversation": conversation.id, "content": "test query"}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) | ||
|
||
def test_invalid_message_format(self): | ||
response = self.client.post("/api/environments/@current/conversations/") | ||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||
|
||
def test_rate_limit_burst(self): | ||
# Create multiple requests to trigger burst rate limit | ||
with patch.object(Assistant, "_stream", return_value=["test response"]): | ||
for _ in range(11): # Assuming burst limit is less than this | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"content": "test query"}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) | ||
|
||
def test_empty_content(self): | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"content": ""}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||
|
||
def test_content_too_long(self): | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"content": "x" * 1001}, # Very long message | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||
|
||
def test_invalid_conversation_id(self): | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{ | ||
"conversation": "not-a-valid-uuid", | ||
"content": "test query", | ||
}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||
|
||
def test_nonexistent_conversation(self): | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{ | ||
"conversation": "12345678-1234-5678-1234-567812345678", | ||
"content": "test query", | ||
}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) | ||
|
||
def test_deleted_conversation(self): | ||
# Create and then delete a conversation | ||
conversation = Conversation.objects.create(user=self.user, team=self.team) | ||
conversation_id = conversation.id | ||
conversation.delete() | ||
|
||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{ | ||
"conversation": str(conversation_id), | ||
"content": "test query", | ||
}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) | ||
|
||
def test_unauthenticated_request(self): | ||
self.client.logout() | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"content": "test query"}, | ||
) | ||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) | ||
|
||
def test_streaming_error_handling(self): | ||
def raise_error(): | ||
yield "some content" | ||
raise Exception("Streaming error") | ||
|
||
with patch.object(Assistant, "_stream", side_effect=raise_error): | ||
response = self.client.post( | ||
f"/api/environments/{self.team.id}/conversations/", | ||
{"content": "test query"}, | ||
) | ||
with self.assertRaises(Exception) as context: | ||
b"".join(response.streaming_content) | ||
self.assertTrue("Streaming error" in str(context.exception)) |
Oops, something went wrong.