-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: add invite expired catch for saml provisioning (#25118)
- Loading branch information
1 parent
aa94ec5
commit c2778a3
Showing
4 changed files
with
43 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
from unittest import mock | ||
from unittest.mock import ANY, patch | ||
from zoneinfo import ZoneInfo | ||
from datetime import timedelta | ||
|
||
import pytest | ||
from django.core import mail | ||
|
@@ -615,7 +616,9 @@ def test_api_social_login_cannot_create_second_organization(self, mock_sso_provi | |
response, "/login?error_code=no_new_organizations" | ||
) # show the user an error; operation not permitted | ||
|
||
def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_capture, use_invite: bool = False): | ||
def run_test_for_allowed_domain( | ||
self, mock_sso_providers, mock_request, mock_capture, use_invite: bool = False, expired_invite: bool = False | ||
): | ||
# Make sure Google Auth is valid for this test instance | ||
mock_sso_providers.return_value = {"google-oauth2": True} | ||
|
||
|
@@ -632,13 +635,17 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap | |
private_project: Team = Team.objects.create( | ||
organization=new_org, name="Private Project", access_control=True | ||
) | ||
OrganizationInvite.objects.create( | ||
invite = OrganizationInvite.objects.create( | ||
target_email="[email protected]", | ||
organization=new_org, | ||
first_name="Jane", | ||
level=OrganizationMembership.Level.MEMBER, | ||
private_project_access=[{"id": private_project.id, "level": ExplicitTeamMembership.Level.ADMIN}], | ||
) | ||
if expired_invite: | ||
invite.created_at = timezone.now() - timedelta(days=30) # Set invite to 30 days old | ||
invite.save() | ||
|
||
user_count = User.objects.count() | ||
response = self.client.get(reverse("social:begin", kwargs={"backend": "google-oauth2"})) | ||
self.assertEqual(response.status_code, status.HTTP_302_FOUND) | ||
|
@@ -667,7 +674,7 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap | |
) | ||
self.assertFalse(mock_capture.call_args.kwargs["properties"]["is_organization_first_user"]) | ||
|
||
if use_invite: | ||
if use_invite and not expired_invite: | ||
# make sure the org invite no longer exists | ||
self.assertEqual( | ||
OrganizationInvite.objects.filter( | ||
|
@@ -684,6 +691,10 @@ def run_test_for_allowed_domain(self, mock_sso_providers, mock_request, mock_cap | |
) | ||
assert explicit_team_membership.level == ExplicitTeamMembership.Level.ADMIN | ||
|
||
if expired_invite: | ||
# Check that the user was still created and added to the organization | ||
self.assertEqual(user.organization, new_org) | ||
|
||
@patch("posthoganalytics.capture") | ||
@mock.patch("social_core.backends.base.BaseAuth.request") | ||
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers") | ||
|
@@ -730,6 +741,26 @@ def test_social_signup_with_allowed_domain_on_cloud_with_existing_invite( | |
self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture, use_invite=True) | ||
assert mock_update_billing_organization_users.called_once() | ||
|
||
@patch("posthoganalytics.capture") | ||
@mock.patch("ee.billing.billing_manager.BillingManager.update_billing_organization_users") | ||
@mock.patch("social_core.backends.base.BaseAuth.request") | ||
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers") | ||
@mock.patch("posthog.tasks.user_identify.identify_task") | ||
@pytest.mark.ee | ||
def test_social_signup_with_allowed_domain_on_cloud_with_existing_expired_invite( | ||
self, | ||
mock_identify, | ||
mock_sso_providers, | ||
mock_request, | ||
mock_update_billing_organization_users, | ||
mock_capture, | ||
): | ||
with self.is_cloud(True): | ||
self.run_test_for_allowed_domain( | ||
mock_sso_providers, mock_request, mock_capture, use_invite=True, expired_invite=True | ||
) | ||
assert mock_update_billing_organization_users.called_once() | ||
|
||
@mock.patch("social_core.backends.base.BaseAuth.request") | ||
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers") | ||
@pytest.mark.ee | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters