From ecd525644db1decb37c2944fd4373d73c7eac88f Mon Sep 17 00:00:00 2001
From: Bianca Yang <ipacifics@gmail.com>
Date: Tue, 21 Nov 2023 10:39:37 -0800
Subject: [PATCH] feat: Update the organization customer email upon signup
 (#18724)

* Update the organization customer email upon signup

* Only do this is the person signing up is the owner of the org and there isn't already a
  stripe subscription associated with the org

* undo local changes that shouldn't be in prod

* small tweaks to match backend

* fix test

* Update query snapshots

* Update query snapshots

---------

Co-authored-by: Bianca Yang <bianca@posthog.com>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
---
 ee/api/test/test_billing.py             |  2 +-
 ee/billing/billing_manager.py           | 11 ++++++++++-
 ee/billing/test/test_billing_manager.py | 23 +++++++++++++++++++++++
 posthog/api/test/test_signup.py         |  5 ++++-
 posthog/models/user.py                  | 19 +++++++++----------
 5 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/ee/api/test/test_billing.py b/ee/api/test/test_billing.py
index 88addd2d7f416..c37c3ee9d6482 100644
--- a/ee/api/test/test_billing.py
+++ b/ee/api/test/test_billing.py
@@ -2,9 +2,9 @@
 from typing import Any, Dict, List
 from unittest.mock import MagicMock, patch
 from uuid import uuid4
+from zoneinfo import ZoneInfo
 
 import jwt
-from zoneinfo import ZoneInfo
 from dateutil.relativedelta import relativedelta
 from django.utils.timezone import now
 from freezegun import freeze_time
diff --git a/ee/billing/billing_manager.py b/ee/billing/billing_manager.py
index 5a8119c57df9b..324b158fe071d 100644
--- a/ee/billing/billing_manager.py
+++ b/ee/billing/billing_manager.py
@@ -6,6 +6,7 @@
 import structlog
 from django.utils import timezone
 from rest_framework.exceptions import NotAuthenticated
+from sentry_sdk import capture_exception
 
 from ee.billing.billing_types import BillingStatus
 from ee.billing.quota_limiting import set_org_usage_summary, sync_org_quota_limits
@@ -13,7 +14,7 @@
 from ee.settings import BILLING_SERVICE_URL
 from posthog.cloud_utils import get_cached_instance_license
 from posthog.models import Organization
-from posthog.models.organization import OrganizationUsageInfo
+from posthog.models.organization import OrganizationMembership, OrganizationUsageInfo
 
 logger = structlog.get_logger(__name__)
 
@@ -114,6 +115,14 @@ def update_billing_distinct_ids(self, organization: Organization) -> None:
         distinct_ids = list(organization.members.values_list("distinct_id", flat=True))
         self.update_billing(organization, {"distinct_ids": distinct_ids})
 
+    def update_billing_customer_email(self, organization: Organization) -> None:
+        try:
+            owner_membership = OrganizationMembership.objects.get(organization=organization, level=15)
+            user = owner_membership.user
+            self.update_billing(organization, {"org_customer_email": user.email})
+        except Exception as e:
+            capture_exception(e)
+
     def deactivate_products(self, organization: Organization, products: str) -> None:
         res = requests.get(
             f"{BILLING_SERVICE_URL}/api/billing/deactivate?products={products}",
diff --git a/ee/billing/test/test_billing_manager.py b/ee/billing/test/test_billing_manager.py
index e0c09e0d071fb..1dbbcb464f068 100644
--- a/ee/billing/test/test_billing_manager.py
+++ b/ee/billing/test/test_billing_manager.py
@@ -33,3 +33,26 @@ def test_update_billing_distinct_ids(self, billing_patch_request_mock: MagicMock
         BillingManager(license).update_billing_distinct_ids(organization)
         assert billing_patch_request_mock.call_count == 1
         assert len(billing_patch_request_mock.call_args[1]["json"]["distinct_ids"]) == 2
+
+    @patch(
+        "ee.billing.billing_manager.requests.patch",
+        return_value=MagicMock(status_code=200, json=MagicMock(return_value={"text": "ok"})),
+    )
+    def test_update_billing_customer_email(self, billing_patch_request_mock: MagicMock):
+        organization = self.organization
+        license = super(LicenseManager, cast(LicenseManager, License.objects)).create(
+            key="key123::key123",
+            plan="enterprise",
+            valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7),
+        )
+        User.objects.create_and_join(
+            organization=organization,
+            email="y@x.com",
+            password=None,
+            level=OrganizationMembership.Level.OWNER,
+        )
+        organization.refresh_from_db()
+        assert len(organization.members.values_list("distinct_id", flat=True)) == 2  # one exists in the test base
+        BillingManager(license).update_billing_customer_email(organization)
+        assert billing_patch_request_mock.call_count == 1
+        assert billing_patch_request_mock.call_args[1]["json"]["org_customer_email"] == "y@x.com"
diff --git a/posthog/api/test/test_signup.py b/posthog/api/test/test_signup.py
index e106dd6cbddf2..00c101e4487ee 100644
--- a/posthog/api/test/test_signup.py
+++ b/posthog/api/test/test_signup.py
@@ -3,9 +3,9 @@
 from typing import Dict, Optional, cast
 from unittest import mock
 from unittest.mock import ANY, patch
+from zoneinfo import ZoneInfo
 
 import pytest
-from zoneinfo import ZoneInfo
 from django.core import mail
 from django.urls.base import reverse
 from django.utils import timezone
@@ -543,6 +543,7 @@ def test_social_signup_with_allowed_domain_on_self_hosted(
 
     @patch("posthoganalytics.capture")
     @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids")
+    @mock.patch("ee.billing.billing_manager.BillingManager.update_billing_customer_email")
     @mock.patch("social_core.backends.base.BaseAuth.request")
     @mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
     @mock.patch("posthog.tasks.user_identify.identify_task")
@@ -553,11 +554,13 @@ def test_social_signup_with_allowed_domain_on_cloud(
         mock_sso_providers,
         mock_request,
         mock_update_distinct_ids,
+        mock_update_billing_customer_email,
         mock_capture,
     ):
         with self.is_cloud(True):
             self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture)
         assert mock_update_distinct_ids.called_once()
+        assert mock_update_billing_customer_email.called_once()
 
     @mock.patch("social_core.backends.base.BaseAuth.request")
     @mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
diff --git a/posthog/models/user.py b/posthog/models/user.py
index 423936747e2cc..353d20ae31d9c 100644
--- a/posthog/models/user.py
+++ b/posthog/models/user.py
@@ -1,14 +1,5 @@
 from functools import cached_property
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    TypedDict,
-)
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict
 
 from django.contrib.auth.models import AbstractUser, BaseUserManager
 from django.db import models, transaction
@@ -237,6 +228,8 @@ def join(
                 # We don't need to check for ExplicitTeamMembership as none can exist for a completely new member
                 self.current_team = organization.teams.order_by("id").filter(access_control=False).first()
             self.save()
+        if level == OrganizationMembership.Level.OWNER and not self.current_organization.customer_id:
+            self.update_billing_customer_email(organization)
         self.update_billing_distinct_ids(organization)
         return membership
 
@@ -268,6 +261,12 @@ def update_billing_distinct_ids(self, organization: Organization) -> None:
         if is_cloud() and get_cached_instance_license() is not None:
             BillingManager(get_cached_instance_license()).update_billing_distinct_ids(organization)
 
+    def update_billing_customer_email(self, organization: Organization) -> None:
+        from ee.billing.billing_manager import BillingManager  # avoid circular import
+
+        if is_cloud() and get_cached_instance_license() is not None:
+            BillingManager(get_cached_instance_license()).update_billing_customer_email(organization)
+
     def get_analytics_metadata(self):
         team_member_count_all: int = (
             OrganizationMembership.objects.filter(organization__in=self.organizations.all())