Skip to content

Commit

Permalink
Merge branch 'master' into feat/slow-rollout-decide
Browse files Browse the repository at this point in the history
  • Loading branch information
benjackwhite authored Dec 18, 2024
2 parents da8c194 + 6dca54c commit 074201a
Show file tree
Hide file tree
Showing 154 changed files with 3,373 additions and 1,729 deletions.
6 changes: 3 additions & 3 deletions cypress/e2e/surveys.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ describe('Surveys', () => {
cy.get('.LemonCollapsePanel').contains('Display conditions').click()
cy.contains('All users').click()
cy.get('.Popover__content').contains('Users who match').click()
cy.contains('Add user targeting').click()
cy.contains('Add property targeting').click()

// select the first property
cy.get('[data-attr="property-select-toggle-0"]').click()
Expand Down Expand Up @@ -144,7 +144,7 @@ describe('Surveys', () => {

// remove user targeting properties
cy.get('.LemonCollapsePanel').contains('Display conditions').click()
cy.contains('Remove all user properties').click()
cy.contains('Remove all property targeting').click()

// save
cy.get('[data-attr="save-survey"]').eq(0).click()
Expand Down Expand Up @@ -197,7 +197,7 @@ describe('Surveys', () => {
cy.get('.LemonCollapsePanel').contains('Display conditions').click()
cy.contains('All users').click()
cy.get('.Popover__content').contains('Users who match').click()
cy.contains('Add user targeting').click()
cy.contains('Add property targeting').click()
cy.get('[data-attr="property-select-toggle-0"]').click()
cy.get('[data-attr="prop-filter-person_properties-0"]').click()
cy.get('[data-attr=prop-val]').click({ force: true })
Expand Down
69 changes: 69 additions & 0 deletions ee/api/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import cast

from django.http import StreamingHttpResponse
from pydantic import ValidationError
from rest_framework import serializers
from rest_framework.renderers import BaseRenderer
from rest_framework.request import Request
from rest_framework.viewsets import GenericViewSet

from ee.hogai.assistant import Assistant
from ee.models.assistant import Conversation
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.models.user import User
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle
from posthog.schema import HumanMessage


class MessageSerializer(serializers.Serializer):
content = serializers.CharField(required=True, max_length=1000)
conversation = serializers.UUIDField(required=False)

def validate(self, data):
try:
message = HumanMessage(content=data["content"])
data["message"] = message
except ValidationError:
raise serializers.ValidationError("Invalid message content.")
return data


class ServerSentEventRenderer(BaseRenderer):
media_type = "text/event-stream"
format = "txt"

def render(self, data, accepted_media_type=None, renderer_context=None):
return data


class ConversationViewSet(TeamAndOrgViewSetMixin, GenericViewSet):
scope_object = "INTERNAL"
serializer_class = MessageSerializer
renderer_classes = [ServerSentEventRenderer]
queryset = Conversation.objects.all()
lookup_url_kwarg = "conversation"

def safely_get_queryset(self, queryset):
# Only allow access to conversations created by the current user
return queryset.filter(user=self.request.user)

def get_throttles(self):
return [AIBurstRateThrottle(), AISustainedRateThrottle()]

def create(self, request: Request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
conversation_id = serializer.validated_data.get("conversation")
if conversation_id:
self.kwargs[self.lookup_url_kwarg] = conversation_id
conversation = self.get_object()
else:
conversation = self.get_queryset().create(user=request.user, team=self.team)
assistant = Assistant(
self.team,
conversation,
serializer.validated_data["message"],
user=cast(User, request.user),
is_new_conversation=not conversation_id,
)
return StreamingHttpResponse(assistant.stream(), content_type=ServerSentEventRenderer.media_type)
157 changes: 157 additions & 0 deletions ee/api/test/test_conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from unittest.mock import patch

from rest_framework import status

from ee.hogai.assistant import Assistant
from ee.models.assistant import Conversation
from posthog.models.team.team import Team
from posthog.models.user import User
from posthog.test.base import APIBaseTest


class TestConversation(APIBaseTest):
def setUp(self):
super().setUp()
self.other_team = Team.objects.create(organization=self.organization, name="other team")
self.other_user = User.objects.create_and_join(
organization=self.organization,
email="[email protected]",
password="password",
first_name="Other",
)

def _get_streaming_content(self, response):
return b"".join(response.streaming_content)

def test_create_conversation(self):
with patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock:
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(self._get_streaming_content(response), b"test response")
self.assertEqual(Conversation.objects.count(), 1)
conversation: Conversation = Conversation.objects.first()
self.assertEqual(conversation.user, self.user)
self.assertEqual(conversation.team, self.team)
stream_mock.assert_called_once()

def test_add_message_to_existing_conversation(self):
with patch.object(Assistant, "_stream", return_value=["test response"]) as stream_mock:
conversation = Conversation.objects.create(user=self.user, team=self.team)
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": str(conversation.id),
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(self._get_streaming_content(response), b"test response")
self.assertEqual(Conversation.objects.count(), 1)
stream_mock.assert_called_once()

def test_cant_access_other_users_conversation(self):
conversation = Conversation.objects.create(user=self.other_user, team=self.team)

self.client.force_login(self.user)
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"conversation": conversation.id, "content": "test query"},
)

self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_cant_access_other_teams_conversation(self):
conversation = Conversation.objects.create(user=self.user, team=self.other_team)
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"conversation": conversation.id, "content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_invalid_message_format(self):
response = self.client.post("/api/environments/@current/conversations/")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_rate_limit_burst(self):
# Create multiple requests to trigger burst rate limit
with patch.object(Assistant, "_stream", return_value=["test response"]):
for _ in range(11): # Assuming burst limit is less than this
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)

def test_empty_content(self):
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"content": ""},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_content_too_long(self):
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"content": "x" * 1001}, # Very long message
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_invalid_conversation_id(self):
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": "not-a-valid-uuid",
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_nonexistent_conversation(self):
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": "12345678-1234-5678-1234-567812345678",
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_deleted_conversation(self):
# Create and then delete a conversation
conversation = Conversation.objects.create(user=self.user, team=self.team)
conversation_id = conversation.id
conversation.delete()

response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": str(conversation_id),
"content": "test query",
},
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_unauthenticated_request(self):
self.client.logout()
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"content": "test query"},
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_streaming_error_handling(self):
def raise_error():
yield "some content"
raise Exception("Streaming error")

with patch.object(Assistant, "_stream", side_effect=raise_error):
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{"content": "test query"},
)
with self.assertRaises(Exception) as context:
b"".join(response.streaming_content)
self.assertTrue("Streaming error" in str(context.exception))
Loading

0 comments on commit 074201a

Please sign in to comment.