Skip to content

Commit

Permalink
fix: add invite expired catch for saml provisioning (#25118)
Browse files Browse the repository at this point in the history
  • Loading branch information
zlwaterfield authored Sep 24, 2024
1 parent aa94ec5 commit c2778a3
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
3 changes: 2 additions & 1 deletion posthog/api/signup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Organization,
OrganizationDomain,
OrganizationInvite,
InviteExpiredException,
Team,
User,
)
Expand Down Expand Up @@ -446,7 +447,7 @@ def process_social_domain_jit_provisioning_signup(
message = "Account unable to be created. This account may already exist. Please try again or use different credentials."
raise ValidationError(message, code="unknown", params={"source": "social_create_user"})

except OrganizationInvite.DoesNotExist:
except (OrganizationInvite.DoesNotExist, InviteExpiredException):
user = User.objects.create_and_join(
organization=domain_instance.organization,
email=email,
Expand Down
37 changes: 34 additions & 3 deletions posthog/api/test/test_signup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import mock
from unittest.mock import ANY, patch
from zoneinfo import ZoneInfo
from datetime import timedelta

import pytest
from django.core import mail
Expand Down Expand Up @@ -615,7 +616,9 @@ def test_api_social_login_cannot_create_second_organization(self, mock_sso_provi
response, "/login?error_code=no_new_organizations"
) # show the user an error; operation not permitted

def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_capture, use_invite: bool = False):
def run_test_for_allowed_domain(
self, mock_sso_providers, mock_request, mock_capture, use_invite: bool = False, expired_invite: bool = False
):
# Make sure Google Auth is valid for this test instance
mock_sso_providers.return_value = {"google-oauth2": True}

Expand All @@ -632,13 +635,17 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap
private_project: Team = Team.objects.create(
organization=new_org, name="Private Project", access_control=True
)
OrganizationInvite.objects.create(
invite = OrganizationInvite.objects.create(
target_email="[email protected]",
organization=new_org,
first_name="Jane",
level=OrganizationMembership.Level.MEMBER,
private_project_access=[{"id": private_project.id, "level": ExplicitTeamMembership.Level.ADMIN}],
)
if expired_invite:
invite.created_at = timezone.now() - timedelta(days=30) # Set invite to 30 days old
invite.save()

user_count = User.objects.count()
response = self.client.get(reverse("social:begin", kwargs={"backend": "google-oauth2"}))
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
Expand Down Expand Up @@ -667,7 +674,7 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap
)
self.assertFalse(mock_capture.call_args.kwargs["properties"]["is_organization_first_user"])

if use_invite:
if use_invite and not expired_invite:
# make sure the org invite no longer exists
self.assertEqual(
OrganizationInvite.objects.filter(
Expand All @@ -684,6 +691,10 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap
)
assert explicit_team_membership.level == ExplicitTeamMembership.Level.ADMIN

if expired_invite:
# Check that the user was still created and added to the organization
self.assertEqual(user.organization, new_org)

@patch("posthoganalytics.capture")
@mock.patch("social_core.backends.base.BaseAuth.request")
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
Expand Down Expand Up @@ -730,6 +741,26 @@ def test_social_signup_with_allowed_domain_on_cloud_with_existing_invite(
self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture, use_invite=True)
assert mock_update_billing_organization_users.called_once()

@patch("posthoganalytics.capture")
@mock.patch("ee.billing.billing_manager.BillingManager.update_billing_organization_users")
@mock.patch("social_core.backends.base.BaseAuth.request")
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
@mock.patch("posthog.tasks.user_identify.identify_task")
@pytest.mark.ee
def test_social_signup_with_allowed_domain_on_cloud_with_existing_expired_invite(
self,
mock_identify,
mock_sso_providers,
mock_request,
mock_update_billing_organization_users,
mock_capture,
):
with self.is_cloud(True):
self.run_test_for_allowed_domain(
mock_sso_providers, mock_request, mock_capture, use_invite=True, expired_invite=True
)
assert mock_update_billing_organization_users.called_once()

@mock.patch("social_core.backends.base.BaseAuth.request")
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
@pytest.mark.ee
Expand Down
2 changes: 1 addition & 1 deletion posthog/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from .notebook import Notebook
from .organization import Organization, OrganizationMembership
from .organization_domain import OrganizationDomain
from .organization_invite import OrganizationInvite
from .organization_invite import OrganizationInvite, InviteExpiredException
from .person import Person, PersonDistinctId, PersonOverride, PersonOverrideMapping
from .personal_api_key import PersonalAPIKey
from .plugin import (
Expand Down
10 changes: 6 additions & 4 deletions posthog/models/organization_invite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def validate_private_project_access(value):
raise exceptions.ValidationError('The "level" field must be either "member" or "admin".')


class InviteExpiredException(exceptions.ValidationError):
def __init__(self, message="This invite has expired. Please ask your admin for a new one."):
super().__init__(message, code="expired")


class OrganizationInvite(UUIDModel):
organization = models.ForeignKey(
"posthog.Organization",
Expand Down Expand Up @@ -85,10 +90,7 @@ def validate(
)

if self.is_expired():
raise exceptions.ValidationError(
"This invite has expired. Please ask your admin for a new one.",
code="expired",
)
raise InviteExpiredException()

if user is None and User.objects.filter(email=invite_email).exists():
raise exceptions.ValidationError(f"/login?next={request_path}", code="account_exists")
Expand Down

0 comments on commit c2778a3

Please sign in to comment.