diff --git a/.vscode/launch.json b/.vscode/launch.json index 51101a38d3d739..ddfb0cf32e5d0f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -71,7 +71,7 @@ "SKIP_SERVICE_VERSION_REQUIREMENTS": "1", "PRINT_SQL": "1", "REPLAY_EVENTS_NEW_CONSUMER_RATIO": "1.0", - "BILLING_SERVICE_URL": "https://billing.dev.posthog.dev" + "BILLING_SERVICE_URL": "http://localhost:8100" // "https://billing.dev.posthog.dev" }, "console": "integratedTerminal", "python": "${workspaceFolder}/env/bin/python", diff --git a/ee/api/test/test_billing.py b/ee/api/test/test_billing.py index 88addd2d7f4161..c37c3ee9d64821 100644 --- a/ee/api/test/test_billing.py +++ b/ee/api/test/test_billing.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List from unittest.mock import MagicMock, patch from uuid import uuid4 +from zoneinfo import ZoneInfo import jwt -from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta from django.utils.timezone import now from freezegun import freeze_time diff --git a/ee/billing/billing_manager.py b/ee/billing/billing_manager.py index 5a8119c57df9b9..ec25b39a64ebb0 100644 --- a/ee/billing/billing_manager.py +++ b/ee/billing/billing_manager.py @@ -6,6 +6,7 @@ import structlog from django.utils import timezone from rest_framework.exceptions import NotAuthenticated +from sentry_sdk import capture_exception from ee.billing.billing_types import BillingStatus from ee.billing.quota_limiting import set_org_usage_summary, sync_org_quota_limits @@ -13,7 +14,7 @@ from ee.settings import BILLING_SERVICE_URL from posthog.cloud_utils import get_cached_instance_license from posthog.models import Organization -from posthog.models.organization import OrganizationUsageInfo +from posthog.models.organization import OrganizationMembership, OrganizationUsageInfo logger = structlog.get_logger(__name__) @@ -114,6 +115,14 @@ def update_billing_distinct_ids(self, organization: Organization) -> None: distinct_ids = list(organization.members.values_list("distinct_id", flat=True)) self.update_billing(organization, {"distinct_ids": distinct_ids}) + def update_billing_customer_email(self, organization: Organization) -> None: + try: + owner_membership = OrganizationMembership.objects.get(organization=organization, level=15) + user = owner_membership.user + self.update_billing(organization, {"org_owner_email": user.email}) + except Exception as e: + capture_exception(e) + def deactivate_products(self, organization: Organization, products: str) -> None: res = requests.get( f"{BILLING_SERVICE_URL}/api/billing/deactivate?products={products}", diff --git a/ee/billing/test/test_billing_manager.py b/ee/billing/test/test_billing_manager.py index e0c09e0d071fb6..3b296545e3bd2f 100644 --- a/ee/billing/test/test_billing_manager.py +++ b/ee/billing/test/test_billing_manager.py @@ -33,3 +33,26 @@ def test_update_billing_distinct_ids(self, billing_patch_request_mock: MagicMock BillingManager(license).update_billing_distinct_ids(organization) assert billing_patch_request_mock.call_count == 1 assert len(billing_patch_request_mock.call_args[1]["json"]["distinct_ids"]) == 2 + + @patch( + "ee.billing.billing_manager.requests.patch", + return_value=MagicMock(status_code=200, json=MagicMock(return_value={"text": "ok"})), + ) + def test_update_billing_customer_email(self, billing_patch_request_mock: MagicMock): + organization = self.organization + license = super(LicenseManager, cast(LicenseManager, License.objects)).create( + key="key123::key123", + plan="enterprise", + valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), + ) + User.objects.create_and_join( + organization=organization, + email="y@x.com", + password=None, + level=OrganizationMembership.Level.OWNER, + ) + organization.refresh_from_db() + assert len(organization.members.values_list("distinct_id", flat=True)) == 2 # one exists in the test base + BillingManager(license).update_billing_customer_email(organization) + assert billing_patch_request_mock.call_count == 1 + assert billing_patch_request_mock.call_args[1]["json"]["org_owner_email"] == "y@x.com" diff --git a/posthog/api/test/test_signup.py b/posthog/api/test/test_signup.py index e106dd6cbddf26..f33a1ebd9195e9 100644 --- a/posthog/api/test/test_signup.py +++ b/posthog/api/test/test_signup.py @@ -3,9 +3,9 @@ from typing import Dict, Optional, cast from unittest import mock from unittest.mock import ANY, patch +from zoneinfo import ZoneInfo import pytest -from zoneinfo import ZoneInfo from django.core import mail from django.urls.base import reverse from django.utils import timezone @@ -546,6 +546,7 @@ def test_social_signup_with_allowed_domain_on_self_hosted( @mock.patch("social_core.backends.base.BaseAuth.request") @mock.patch("posthog.api.authentication.get_instance_available_sso_providers") @mock.patch("posthog.tasks.user_identify.identify_task") + @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_customer_email") @pytest.mark.ee def test_social_signup_with_allowed_domain_on_cloud( self, @@ -553,11 +554,13 @@ def test_social_signup_with_allowed_domain_on_cloud( mock_sso_providers, mock_request, mock_update_distinct_ids, + mock_update_billing_customer_email, mock_capture, ): with self.is_cloud(True): self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture) assert mock_update_distinct_ids.called_once() + assert mock_update_billing_customer_email.called_once() @mock.patch("social_core.backends.base.BaseAuth.request") @mock.patch("posthog.api.authentication.get_instance_available_sso_providers") diff --git a/posthog/cloud_utils.py b/posthog/cloud_utils.py index cc0aec67d15ff8..093e7e46dc161e 100644 --- a/posthog/cloud_utils.py +++ b/posthog/cloud_utils.py @@ -14,6 +14,7 @@ # NOTE: This is cached for the lifetime of the instance but this is not an issue as the value is not expected to change def is_cloud(): + return True global is_cloud_cached if not settings.EE_AVAILABLE: diff --git a/posthog/models/user.py b/posthog/models/user.py index 423936747e2cc1..353d20ae31d9c6 100644 --- a/posthog/models/user.py +++ b/posthog/models/user.py @@ -1,14 +1,5 @@ from functools import cached_property -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - TypedDict, -) +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict from django.contrib.auth.models import AbstractUser, BaseUserManager from django.db import models, transaction @@ -237,6 +228,8 @@ def join( # We don't need to check for ExplicitTeamMembership as none can exist for a completely new member self.current_team = organization.teams.order_by("id").filter(access_control=False).first() self.save() + if level == OrganizationMembership.Level.OWNER and not self.current_organization.customer_id: + self.update_billing_customer_email(organization) self.update_billing_distinct_ids(organization) return membership @@ -268,6 +261,12 @@ def update_billing_distinct_ids(self, organization: Organization) -> None: if is_cloud() and get_cached_instance_license() is not None: BillingManager(get_cached_instance_license()).update_billing_distinct_ids(organization) + def update_billing_customer_email(self, organization: Organization) -> None: + from ee.billing.billing_manager import BillingManager # avoid circular import + + if is_cloud() and get_cached_instance_license() is not None: + BillingManager(get_cached_instance_license()).update_billing_customer_email(organization) + def get_analytics_metadata(self): team_member_count_all: int = ( OrganizationMembership.objects.filter(organization__in=self.organizations.all())