diff --git a/ee/billing/billing_manager.py b/ee/billing/billing_manager.py index 904c6b63b9f17..24e303c37ec52 100644 --- a/ee/billing/billing_manager.py +++ b/ee/billing/billing_manager.py @@ -1,4 +1,5 @@ from django.conf import settings +from django.db.models import F from datetime import datetime, timedelta from enum import Enum from typing import Any, Optional, cast @@ -7,6 +8,7 @@ import requests import structlog from django.utils import timezone +from sentry_sdk import capture_message from requests import JSONDecodeError # type: ignore[attr-defined] from rest_framework.exceptions import NotAuthenticated from sentry_sdk import capture_exception @@ -126,26 +128,50 @@ def update_billing(self, organization: Organization, data: dict[str, Any]) -> No handle_billing_service_error(res) - def update_billing_distinct_ids(self, organization: Organization) -> None: - distinct_ids = list(organization.members.values_list("distinct_id", flat=True)) - self.update_billing(organization, {"distinct_ids": distinct_ids}) - - def update_billing_customer_email(self, organization: Organization) -> None: + def update_billing_organization_users(self, organization: Organization) -> None: try: - owner_membership = OrganizationMembership.objects.get(organization=organization, level=15) - user = owner_membership.user - self.update_billing(organization, {"org_customer_email": user.email}) - except Exception as e: - capture_exception(e) + distinct_ids = list(organization.members.values_list("distinct_id", flat=True)) + + first_owner_membership = ( + OrganizationMembership.objects.filter(organization=organization, level=15) + .order_by("-joined_at") + .first() + ) + if not first_owner_membership: + capture_message(f"No owner membership found for organization {organization.id}") + return + first_owner = first_owner_membership.user - def update_billing_admin_emails(self, organization: Organization) -> None: - try: admin_emails = list( organization.members.filter( organization_membership__level__gte=OrganizationMembership.Level.ADMIN ).values_list("email", flat=True) ) - self.update_billing(organization, {"org_admin_emails": admin_emails}) + + org_users = list( + organization.members.values( + "email", + "distinct_id", + "organization_membership__level", + ) + .annotate(role=F("organization_membership__level")) + .filter(role__gte=OrganizationMembership.Level.ADMIN) + .values( + "email", + "distinct_id", + "role", + ) + ) + + self.update_billing( + organization, + { + "distinct_ids": distinct_ids, + "org_customer_email": first_owner.email, + "org_admin_emails": admin_emails, + "org_users": org_users, + }, + ) except Exception as e: capture_exception(e) diff --git a/ee/billing/test/test_billing_manager.py b/ee/billing/test/test_billing_manager.py index 17e9a4299dd59..1c5ed9d2ba859 100644 --- a/ee/billing/test/test_billing_manager.py +++ b/ee/billing/test/test_billing_manager.py @@ -58,37 +58,14 @@ def test_get_billing_unlicensed(self, billing_patch_request_mock): "ee.billing.billing_manager.requests.patch", return_value=MagicMock(status_code=200, json=MagicMock(return_value={"text": "ok"})), ) - def test_update_billing_distinct_ids(self, billing_patch_request_mock: MagicMock): + def test_update_billing_organization_users(self, billing_patch_request_mock: MagicMock): organization = self.organization license = super(LicenseManager, cast(LicenseManager, License.objects)).create( key="key123::key123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), ) - User.objects.create_and_join( - organization=organization, - email="y@x.com", - password=None, - level=OrganizationMembership.Level.ADMIN, - ) - organization.refresh_from_db() - assert len(organization.members.values_list("distinct_id", flat=True)) == 2 # one exists in the test base - BillingManager(license).update_billing_distinct_ids(organization) - assert billing_patch_request_mock.call_count == 1 - assert len(billing_patch_request_mock.call_args[1]["json"]["distinct_ids"]) == 2 - - @patch( - "ee.billing.billing_manager.requests.patch", - return_value=MagicMock(status_code=200, json=MagicMock(return_value={"text": "ok"})), - ) - def test_update_billing_customer_email(self, billing_patch_request_mock: MagicMock): - organization = self.organization - license = super(LicenseManager, cast(LicenseManager, License.objects)).create( - key="key123::key123", - plan="enterprise", - valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), - ) - User.objects.create_and_join( + y = User.objects.create_and_join( organization=organization, email="y@x.com", password=None, @@ -96,15 +73,20 @@ def test_update_billing_customer_email(self, billing_patch_request_mock: MagicMo ) organization.refresh_from_db() assert len(organization.members.values_list("distinct_id", flat=True)) == 2 # one exists in the test base - BillingManager(license).update_billing_customer_email(organization) + BillingManager(license).update_billing_organization_users(organization) assert billing_patch_request_mock.call_count == 1 + assert len(billing_patch_request_mock.call_args[1]["json"]["distinct_ids"]) == 2 assert billing_patch_request_mock.call_args[1]["json"]["org_customer_email"] == "y@x.com" + assert billing_patch_request_mock.call_args[1]["json"]["org_admin_emails"] == ["y@x.com"] + assert billing_patch_request_mock.call_args[1]["json"]["org_users"] == [ + {"email": "y@x.com", "distinct_id": y.distinct_id, "role": 15}, + ] @patch( "ee.billing.billing_manager.requests.patch", return_value=MagicMock(status_code=200, json=MagicMock(return_value={"text": "ok"})), ) - def test_update_billing_admin_emails(self, billing_patch_request_mock: MagicMock): + def test_update_billing_organization_users_with_multiple_members(self, billing_patch_request_mock: MagicMock): organization = self.organization license = super(LicenseManager, cast(LicenseManager, License.objects)).create( key="key123::key123", @@ -114,22 +96,32 @@ def test_update_billing_admin_emails(self, billing_patch_request_mock: MagicMock User.objects.create_and_join( organization=organization, email="y1@x.com", + first_name="y1", + last_name="y1", password=None, level=OrganizationMembership.Level.MEMBER, ) - User.objects.create_and_join( + y2 = User.objects.create_and_join( organization=organization, email="y2@x.com", + first_name="y2", + last_name="y2", password=None, level=OrganizationMembership.Level.ADMIN, ) - User.objects.create_and_join( + y3 = User.objects.create_and_join( organization=organization, email="y3@x.com", password=None, level=OrganizationMembership.Level.OWNER, ) organization.refresh_from_db() - BillingManager(license).update_billing_admin_emails(organization) + BillingManager(license).update_billing_organization_users(organization) assert billing_patch_request_mock.call_count == 1 + assert len(billing_patch_request_mock.call_args[1]["json"]["distinct_ids"]) == 4 + assert billing_patch_request_mock.call_args[1]["json"]["org_customer_email"] == "y3@x.com" assert sorted(billing_patch_request_mock.call_args[1]["json"]["org_admin_emails"]) == ["y2@x.com", "y3@x.com"] + assert billing_patch_request_mock.call_args[1]["json"]["org_users"] == [ + {"email": "y2@x.com", "distinct_id": y2.distinct_id, "role": 8}, + {"email": "y3@x.com", "distinct_id": y3.distinct_id, "role": 15}, + ] diff --git a/posthog/api/organization_member.py b/posthog/api/organization_member.py index 2bb4def355f94..84a54276e875b 100644 --- a/posthog/api/organization_member.py +++ b/posthog/api/organization_member.py @@ -79,7 +79,7 @@ def update(self, updated_membership, validated_data, **kwargs): setattr(updated_membership, attr, value) updated_membership.save() if level_changed: - self.context["request"].user.update_billing_admin_emails(updated_membership.organization) + self.context["request"].user.update_billing_organization_users(updated_membership.organization) return updated_membership diff --git a/posthog/api/test/test_organization_members.py b/posthog/api/test/test_organization_members.py index bd2fe692209ad..c122e41794ac7 100644 --- a/posthog/api/test/test_organization_members.py +++ b/posthog/api/test/test_organization_members.py @@ -50,8 +50,8 @@ def test_cant_list_members_for_an_alien_organization(self): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.json(), self.permission_denied_response()) - @patch("posthog.models.user.User.update_billing_admin_emails") - def test_delete_organization_member(self, mock_update_billing_admin_emails): + @patch("posthog.models.user.User.update_billing_organization_users") + def test_delete_organization_member(self, mock_update_billing_organization_users): user = User.objects.create_and_join(self.organization, "test@x.com", None, "X") membership_queryset = OrganizationMembership.objects.filter(user=user, organization=self.organization) self.assertTrue(membership_queryset.exists()) @@ -66,26 +66,27 @@ def test_delete_organization_member(self, mock_update_billing_admin_emails): self.assertEqual(response.status_code, 204) self.assertFalse(membership_queryset.exists(), False) - assert mock_update_billing_admin_emails.call_count == 1 - assert mock_update_billing_admin_emails.call_args_list == [ + assert mock_update_billing_organization_users.call_count == 2 + assert mock_update_billing_organization_users.call_args_list == [ + call(self.organization), call(self.organization), ] - @patch("posthog.models.user.User.update_billing_admin_emails") - def test_leave_organization(self, mock_update_billing_admin_emails): + @patch("posthog.models.user.User.update_billing_organization_users") + def test_leave_organization(self, mock_update_billing_organization_users): membership_queryset = OrganizationMembership.objects.filter(user=self.user, organization=self.organization) self.assertEqual(membership_queryset.count(), 1) response = self.client.delete(f"/api/organizations/@current/members/{self.user.uuid}/") self.assertEqual(response.status_code, 204) self.assertEqual(membership_queryset.count(), 0) - assert mock_update_billing_admin_emails.call_count == 1 - assert mock_update_billing_admin_emails.call_args_list == [ + assert mock_update_billing_organization_users.call_count == 1 + assert mock_update_billing_organization_users.call_args_list == [ call(self.organization), ] - @patch("posthog.models.user.User.update_billing_admin_emails") - def test_change_organization_member_level(self, mock_update_billing_admin_emails): + @patch("posthog.models.user.User.update_billing_organization_users") + def test_change_organization_member_level(self, mock_update_billing_organization_users): self.organization_membership.level = OrganizationMembership.Level.OWNER self.organization_membership.save() user = User.objects.create_user("test@x.com", None, "X") @@ -120,13 +121,13 @@ def test_change_organization_member_level(self, mock_update_billing_admin_emails "level": OrganizationMembership.Level.ADMIN.value, }, ) - assert mock_update_billing_admin_emails.call_count == 1 - assert mock_update_billing_admin_emails.call_args_list == [ + assert mock_update_billing_organization_users.call_count == 1 + assert mock_update_billing_organization_users.call_args_list == [ call(self.organization), ] - @patch("posthog.models.user.User.update_billing_admin_emails") - def test_admin_can_promote_to_admin(self, mock_update_billing_admin_emails): + @patch("posthog.models.user.User.update_billing_organization_users") + def test_admin_can_promote_to_admin(self, mock_update_billing_organization_users): self.organization_membership.level = OrganizationMembership.Level.ADMIN self.organization_membership.save() user = User.objects.create_user("test@x.com", None, "X") @@ -140,13 +141,13 @@ def test_admin_can_promote_to_admin(self, mock_update_billing_admin_emails): updated_membership = OrganizationMembership.objects.get(user=user, organization=self.organization) self.assertEqual(updated_membership.level, OrganizationMembership.Level.ADMIN) - assert mock_update_billing_admin_emails.call_count == 1 - assert mock_update_billing_admin_emails.call_args_list == [ + assert mock_update_billing_organization_users.call_count == 1 + assert mock_update_billing_organization_users.call_args_list == [ call(self.organization), ] - @patch("posthog.models.user.User.update_billing_admin_emails") - def test_change_organization_member_level_requires_admin(self, mock_update_billing_admin_emails): + @patch("posthog.models.user.User.update_billing_organization_users") + def test_change_organization_member_level_requires_admin(self, mock_update_billing_organization_users): user = User.objects.create_user("test@x.com", None, "X") membership = OrganizationMembership.objects.create(user=user, organization=self.organization) self.assertEqual(membership.level, OrganizationMembership.Level.MEMBER) @@ -168,7 +169,7 @@ def test_change_organization_member_level_requires_admin(self, mock_update_billi ) self.assertEqual(response.status_code, 403) - assert mock_update_billing_admin_emails.call_count == 0 + assert mock_update_billing_organization_users.call_count == 0 def test_cannot_change_own_organization_member_level(self): self.organization_membership.level = OrganizationMembership.Level.ADMIN diff --git a/posthog/api/test/test_signup.py b/posthog/api/test/test_signup.py index 919f029787483..d329f3de75d83 100644 --- a/posthog/api/test/test_signup.py +++ b/posthog/api/test/test_signup.py @@ -695,9 +695,7 @@ def test_social_signup_with_allowed_domain_on_self_hosted( self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture) @patch("posthoganalytics.capture") - @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids") - @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_customer_email") - @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_admin_emails") + @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_organization_users") @mock.patch("social_core.backends.base.BaseAuth.request") @mock.patch("posthog.api.authentication.get_instance_available_sso_providers") @mock.patch("posthog.tasks.user_identify.identify_task") @@ -707,21 +705,15 @@ def test_social_signup_with_allowed_domain_on_cloud( mock_identify, mock_sso_providers, mock_request, - mock_update_distinct_ids, - mock_update_billing_customer_email, - mock_update_billing_admin_emails, + mock_update_billing_organization_users, mock_capture, ): with self.is_cloud(True): self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture) - assert mock_update_distinct_ids.called_once() - assert mock_update_billing_customer_email.called_once() - assert mock_update_billing_admin_emails.called_once() + assert mock_update_billing_organization_users.called_once() @patch("posthoganalytics.capture") - @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids") - @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_customer_email") - @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_admin_emails") + @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_organization_users") @mock.patch("social_core.backends.base.BaseAuth.request") @mock.patch("posthog.api.authentication.get_instance_available_sso_providers") @mock.patch("posthog.tasks.user_identify.identify_task") @@ -731,16 +723,12 @@ def test_social_signup_with_allowed_domain_on_cloud_with_existing_invite( mock_identify, mock_sso_providers, mock_request, - mock_update_distinct_ids, - mock_update_billing_customer_email, - mock_update_billing_admin_emails, + mock_update_billing_organization_users, mock_capture, ): with self.is_cloud(True): self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture, use_invite=True) - assert mock_update_distinct_ids.called_once() - assert mock_update_billing_customer_email.called_once() - assert mock_update_billing_admin_emails.called_once() + assert mock_update_billing_organization_users.called_once() @mock.patch("social_core.backends.base.BaseAuth.request") @mock.patch("posthog.api.authentication.get_instance_available_sso_providers") @@ -1285,8 +1273,10 @@ def test_api_invite_sign_up_member_joined_email_is_not_sent_if_disabled(self): self.assertEqual(len(mail.outbox), 0) @patch("posthoganalytics.capture") - @patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids") - def test_existing_user_can_sign_up_to_a_new_organization(self, mock_update_distinct_ids, mock_capture): + @patch("ee.billing.billing_manager.BillingManager.update_billing_organization_users") + def test_existing_user_can_sign_up_to_a_new_organization( + self, mock_update_billing_organization_users, mock_capture + ): user = self._create_user("test+159@posthog.com", VALID_TEST_PASSWORD) new_org = Organization.objects.create(name="TestCo") new_team = Team.objects.create(organization=new_org) @@ -1364,7 +1354,7 @@ def test_existing_user_can_sign_up_to_a_new_organization(self, mock_update_disti self.assertEqual(response.status_code, status.HTTP_200_OK) # Assert that the org's distinct IDs are sent to billing - mock_update_distinct_ids.assert_called_once_with(new_org) + mock_update_billing_organization_users.assert_called_once_with(new_org) @patch("posthoganalytics.capture") def test_cannot_use_claim_invite_endpoint_to_update_user(self, mock_capture): diff --git a/posthog/management/commands/sync_to_billing.py b/posthog/management/commands/sync_to_billing.py index 77a06a1fe77ce..1fd21b59e79c3 100644 --- a/posthog/management/commands/sync_to_billing.py +++ b/posthog/management/commands/sync_to_billing.py @@ -18,7 +18,7 @@ def handle(self, *args, **options): action = options["action"] organization_ids = options["organization_ids"] - if action not in ["distinct_ids", "admin_emails", "customer_email"]: + if action not in ["organization_users"]: print("Invalid action, please select 'distinct_ids', 'admin_emails' or 'customer_email'") # noqa T201 return @@ -36,12 +36,8 @@ def handle(self, *args, **options): if not first_owner: print(f"Organization {organization.id} has no owner") # noqa T201 - if action == "distinct_ids": - first_owner.update_billing_distinct_ids(organization) - elif action == "admin_emails": - first_owner.update_billing_admin_emails(organization) - elif action == "customer_email": - first_owner.update_billing_customer_email(organization) + if action == "organization_users": + first_owner.update_billing_organization_users(organization) if index % 50 == 0: print(f"Processed {index} organizations out of {len(organizations)}") # noqa T201 diff --git a/posthog/models/user.py b/posthog/models/user.py index 466f4ea1c0985..f896ab7bb9a6b 100644 --- a/posthog/models/user.py +++ b/posthog/models/user.py @@ -246,11 +246,7 @@ def join( # We don't need to check for ExplicitTeamMembership as none can exist for a completely new member self.current_team = organization.teams.order_by("id").filter(access_control=False).first() self.save() - if level == OrganizationMembership.Level.OWNER and not self.current_organization.customer_id: - self.update_billing_customer_email(organization) - if level >= OrganizationMembership.Level.ADMIN: - self.update_billing_admin_emails(organization) - self.update_billing_distinct_ids(organization) + self.update_billing_organization_users(organization) return membership @property @@ -273,26 +269,13 @@ def leave(self, *, organization: Organization) -> None: ) self.team = self.current_team # Update cached property self.save() - self.update_billing_admin_emails(organization) - self.update_billing_distinct_ids(organization) + self.update_billing_organization_users(organization) - def update_billing_distinct_ids(self, organization: Organization) -> None: + def update_billing_organization_users(self, organization: Organization) -> None: from ee.billing.billing_manager import BillingManager # avoid circular import if is_cloud() and get_cached_instance_license() is not None: - BillingManager(get_cached_instance_license()).update_billing_distinct_ids(organization) - - def update_billing_customer_email(self, organization: Organization) -> None: - from ee.billing.billing_manager import BillingManager # avoid circular import - - if is_cloud() and get_cached_instance_license() is not None: - BillingManager(get_cached_instance_license()).update_billing_customer_email(organization) - - def update_billing_admin_emails(self, organization: Organization) -> None: - from ee.billing.billing_manager import BillingManager - - if is_cloud() and get_cached_instance_license() is not None: - BillingManager(get_cached_instance_license()).update_billing_admin_emails(organization) + BillingManager(get_cached_instance_license()).update_billing_organization_users(organization) def get_analytics_metadata(self): team_member_count_all: int = (