diff --git a/posthog/api/organization_invite.py b/posthog/api/organization_invite.py index 99b6ddbf7bef0..de572de9656a9 100644 --- a/posthog/api/organization_invite.py +++ b/posthog/api/organization_invite.py @@ -1,6 +1,9 @@ +from datetime import datetime, timedelta from typing import Any, Optional, cast +from uuid import UUID import posthoganalytics +from django.db.models import QuerySet from rest_framework import ( exceptions, mixins, @@ -15,6 +18,7 @@ from posthog.api.routing import TeamAndOrgViewSetMixin from posthog.api.shared import UserBasicSerializer from posthog.api.utils import action +from posthog.constants import INVITE_DAYS_VALIDITY from posthog.email import is_email_available from posthog.event_usage import report_bulk_invited, report_team_member_invited from posthog.models import OrganizationInvite, OrganizationMembership @@ -24,9 +28,85 @@ from posthog.tasks.email import send_invite +class OrganizationInviteManager: + @staticmethod + def combine_invites( + organization_id: UUID | str, validated_data: dict[str, Any], combine_pending_invites: bool = True + ) -> dict[str, Any]: + """Combines multiple pending invites for the same email address.""" + if not combine_pending_invites: + return validated_data + + existing_invites = OrganizationInviteManager._get_invites_for_user_org( + organization_id=organization_id, target_email=validated_data["target_email"] + ) + + if not existing_invites.exists(): + return validated_data + + validated_data["level"] = OrganizationInviteManager._get_highest_level( + existing_invites=existing_invites, + new_level=validated_data.get("level", OrganizationMembership.Level.MEMBER), + ) + + validated_data["private_project_access"] = OrganizationInviteManager._combine_project_access( + existing_invites=existing_invites, new_access=validated_data.get("private_project_access", []) + ) + + return validated_data + + @staticmethod + def _get_invites_for_user_org( + organization_id: UUID | str, target_email: str, include_expired: bool = False + ) -> QuerySet: + filters: dict[str, Any] = { + "organization_id": organization_id, + "target_email": target_email, + } + + if not include_expired: + filters["created_at__gt"] = datetime.now() - timedelta(days=INVITE_DAYS_VALIDITY) + + return OrganizationInvite.objects.filter(**filters).order_by("-created_at") + + @staticmethod + def _get_highest_level(existing_invites: QuerySet, new_level: int) -> int: + levels = [invite.level for invite in existing_invites] + levels.append(new_level) + return max(levels) + + @staticmethod + def _combine_project_access(existing_invites: QuerySet, new_access: list[dict]) -> list[dict]: + combined_access: dict[int, int] = {} + + # Add new access first + for access in new_access: + combined_access[access["id"]] = access["level"] + + # Combine with existing access, keeping highest levels + for invite in existing_invites: + if not invite.private_project_access: + continue + + for access in invite.private_project_access: + project_id = access["id"] + if project_id not in combined_access or access["level"] > combined_access[project_id]: + combined_access[project_id] = access["level"] + + return [{"id": project_id, "level": level} for project_id, level in combined_access.items()] + + @staticmethod + def delete_existing_invites(organization_id: UUID | str, target_email: str) -> None: + """Deletes all existing invites for a given email in an organization.""" + OrganizationInviteManager._get_invites_for_user_org( + organization_id=organization_id, target_email=target_email, include_expired=True + ).delete() + + class OrganizationInviteSerializer(serializers.ModelSerializer): created_by = UserBasicSerializer(read_only=True) send_email = serializers.BooleanField(write_only=True, default=True) + combine_pending_invites = serializers.BooleanField(write_only=True, default=False) class Meta: model = OrganizationInvite @@ -43,6 +123,7 @@ class Meta: "message", "private_project_access", "send_email", + "combine_pending_invites", ] read_only_fields = [ "id", @@ -107,12 +188,30 @@ def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> O user__email=validated_data["target_email"], ).exists(): raise exceptions.ValidationError("A user with this email address already belongs to the organization.") + + combine_pending_invites = validated_data.pop("combine_pending_invites", False) send_email = validated_data.pop("send_email", True) + + # Handle invite combination if requested + if combine_pending_invites: + validated_data = OrganizationInviteManager.combine_invites( + organization_id=self.context["organization_id"], + validated_data=validated_data, + combine_pending_invites=True, + ) + + # Delete existing invites for this email + OrganizationInviteManager.delete_existing_invites( + organization_id=self.context["organization_id"], target_email=validated_data["target_email"] + ) + + # Create new invite invite: OrganizationInvite = OrganizationInvite.objects.create( organization_id=self.context["organization_id"], created_by=self.context["request"].user, **validated_data, ) + if is_email_available(with_absolute_urls=True) and send_email: invite.emailing_attempt_made = True send_invite(invite_id=invite.id) diff --git a/posthog/api/test/test_organization_invites.py b/posthog/api/test/test_organization_invites.py index a486d796a9472..017fc5d05720e 100644 --- a/posthog/api/test/test_organization_invites.py +++ b/posthog/api/test/test_organization_invites.py @@ -2,6 +2,7 @@ from unittest.mock import ANY, patch from django.core import mail +from freezegun import freeze_time from rest_framework import status from ee.models.explicit_team_membership import ExplicitTeamMembership @@ -156,18 +157,18 @@ def test_add_organization_invite_with_email_on_instance_but_send_email_prop_fals # Assert invite email is not sent self.assertEqual(len(mail.outbox), 0) - def test_can_create_invites_for_the_same_email_multiple_times(self): + def test_create_invites_for_the_same_email_multiple_times_deletes_older_invites(self): email = "x@posthog.com" count = OrganizationInvite.objects.count() - for _ in range(0, 2): + for _ in range(0, 3): response = self.client.post("/api/organizations/@current/invites/", {"target_email": email}) self.assertEqual(response.status_code, status.HTTP_201_CREATED) obj = OrganizationInvite.objects.get(id=response.json()["id"]) self.assertEqual(obj.target_email, email) self.assertEqual(obj.created_by, self.user) - self.assertEqual(OrganizationInvite.objects.count(), count + 2) + self.assertEqual(OrganizationInvite.objects.count(), count + 1) def test_can_specify_membership_level_in_invite(self): email = "x@posthog.com" @@ -508,3 +509,148 @@ def test_delete_organization_invite_if_plain_member(self): self.assertEqual(response.content, b"") # Empty response self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertFalse(OrganizationInvite.objects.exists()) + + # Combine pending invites + + def test_combine_pending_invites_combines_levels_and_project_access(self): + email = "x@posthog.com" + private_team_1 = Team.objects.create(organization=self.organization, name="Private Team 1", access_control=True) + private_team_2 = Team.objects.create(organization=self.organization, name="Private Team 2", access_control=True) + + ExplicitTeamMembership.objects.create( + team=private_team_1, + parent_membership=self.organization_membership, + level=ExplicitTeamMembership.Level.ADMIN, + ) + ExplicitTeamMembership.objects.create( + team=private_team_2, + parent_membership=self.organization_membership, + level=ExplicitTeamMembership.Level.ADMIN, + ) + + # Create first invite with member access to team 1 + first_invite = self.client.post( + "/api/organizations/@current/invites/", + { + "target_email": email, + "level": OrganizationMembership.Level.MEMBER, + "private_project_access": [{"id": private_team_1.id, "level": ExplicitTeamMembership.Level.MEMBER}], + }, + ).json() + + # Create second invite with admin access to team 2 + second_invite = self.client.post( + "/api/organizations/@current/invites/", + { + "target_email": email, + "level": OrganizationMembership.Level.ADMIN, + "private_project_access": [{"id": private_team_2.id, "level": ExplicitTeamMembership.Level.ADMIN}], + }, + ).json() + + # Create third invite combining previous invites + response = self.client.post( + "/api/organizations/@current/invites/", + { + "target_email": email, + "level": OrganizationMembership.Level.MEMBER, + "private_project_access": [{"id": private_team_1.id, "level": ExplicitTeamMembership.Level.ADMIN}], + "combine_pending_invites": True, + }, + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + combined_invite = response.json() + + # Check that previous invites are deleted + self.assertFalse(OrganizationInvite.objects.filter(id=first_invite["id"]).exists()) + self.assertFalse(OrganizationInvite.objects.filter(id=second_invite["id"]).exists()) + + # Check that the new invite has the highest level (ADMIN) + self.assertEqual(combined_invite["level"], OrganizationMembership.Level.ADMIN) + + # Check that private project access is combined with highest levels + expected_access = [ + {"id": private_team_1.id, "level": ExplicitTeamMembership.Level.ADMIN}, + {"id": private_team_2.id, "level": ExplicitTeamMembership.Level.ADMIN}, + ] + self.assertEqual(len(combined_invite["private_project_access"]), 2) + for access in expected_access: + self.assertIn(access, combined_invite["private_project_access"]) + + def test_combine_pending_invites_with_no_existing_invites(self): + email = "x@posthog.com" + response = self.client.post( + "/api/organizations/@current/invites/", + { + "target_email": email, + "level": OrganizationMembership.Level.MEMBER, + "combine_pending_invites": True, + }, + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + invite = response.json() + self.assertEqual(invite["level"], OrganizationMembership.Level.MEMBER) + self.assertEqual(invite["target_email"], email) + self.assertEqual(invite["private_project_access"], []) + + @freeze_time("2024-01-10") + def test_combine_pending_invites_with_expired_invites(self): + email = "xyz@posthog.com" + + # Create an expired invite + with freeze_time("2023-01-05"): + OrganizationInvite.objects.create( + organization=self.organization, + target_email=email, + level=OrganizationMembership.Level.ADMIN, + ) + + response = self.client.post( + "/api/organizations/@current/invites/", + { + "target_email": email, + "level": OrganizationMembership.Level.MEMBER, + "combine_pending_invites": True, + }, + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + invite = response.json() + + # Check that the new invite uses its own level, not the expired invite's level + self.assertEqual(invite["level"], OrganizationMembership.Level.MEMBER) + self.assertEqual(invite["target_email"], email) + self.assertEqual(invite["private_project_access"], []) + + def test_combine_pending_invites_false_expires_existing_invites(self): + email = "x@posthog.com" + + # Create first invite + first_invite = self.client.post( + "/api/organizations/@current/invites/", + { + "target_email": email, + "level": OrganizationMembership.Level.ADMIN, + }, + ).json() + + # Create second invite with combine_pending_invites=False + response = self.client.post( + "/api/organizations/@current/invites/", + { + "target_email": email, + "level": OrganizationMembership.Level.MEMBER, + "combine_pending_invites": False, + }, + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + new_invite = response.json() + + # Check that previous invite is deleted + self.assertFalse(OrganizationInvite.objects.filter(id=first_invite["id"]).exists()) + + # Check that new invite uses its own level + self.assertEqual(new_invite["level"], OrganizationMembership.Level.MEMBER)