Skip to content

Commit

Permalink
feat: Update the organization customer email upon signup (#18724)
Browse files Browse the repository at this point in the history
* Update the organization customer email upon signup

* Only do this is the person signing up is the owner of the org and there isn't already a
  stripe subscription associated with the org

* undo local changes that shouldn't be in prod

* small tweaks to match backend

* fix test

* Update query snapshots

* Update query snapshots

---------

Co-authored-by: Bianca Yang <[email protected]>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 21, 2023
1 parent a5e2ca1 commit ecd5256
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ee/api/test/test_billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
from uuid import uuid4
from zoneinfo import ZoneInfo

import jwt
from zoneinfo import ZoneInfo
from dateutil.relativedelta import relativedelta
from django.utils.timezone import now
from freezegun import freeze_time
Expand Down
11 changes: 10 additions & 1 deletion ee/billing/billing_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import structlog
from django.utils import timezone
from rest_framework.exceptions import NotAuthenticated
from sentry_sdk import capture_exception

from ee.billing.billing_types import BillingStatus
from ee.billing.quota_limiting import set_org_usage_summary, sync_org_quota_limits
from ee.models import License
from ee.settings import BILLING_SERVICE_URL
from posthog.cloud_utils import get_cached_instance_license
from posthog.models import Organization
from posthog.models.organization import OrganizationUsageInfo
from posthog.models.organization import OrganizationMembership, OrganizationUsageInfo

logger = structlog.get_logger(__name__)

Expand Down Expand Up @@ -114,6 +115,14 @@ def update_billing_distinct_ids(self, organization: Organization) -> None:
distinct_ids = list(organization.members.values_list("distinct_id", flat=True))
self.update_billing(organization, {"distinct_ids": distinct_ids})

def update_billing_customer_email(self, organization: Organization) -> None:
try:
owner_membership = OrganizationMembership.objects.get(organization=organization, level=15)
user = owner_membership.user
self.update_billing(organization, {"org_customer_email": user.email})
except Exception as e:
capture_exception(e)

def deactivate_products(self, organization: Organization, products: str) -> None:
res = requests.get(
f"{BILLING_SERVICE_URL}/api/billing/deactivate?products={products}",
Expand Down
23 changes: 23 additions & 0 deletions ee/billing/test/test_billing_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,26 @@ def test_update_billing_distinct_ids(self, billing_patch_request_mock: MagicMock
BillingManager(license).update_billing_distinct_ids(organization)
assert billing_patch_request_mock.call_count == 1
assert len(billing_patch_request_mock.call_args[1]["json"]["distinct_ids"]) == 2

@patch(
"ee.billing.billing_manager.requests.patch",
return_value=MagicMock(status_code=200, json=MagicMock(return_value={"text": "ok"})),
)
def test_update_billing_customer_email(self, billing_patch_request_mock: MagicMock):
organization = self.organization
license = super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key123::key123",
plan="enterprise",
valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7),
)
User.objects.create_and_join(
organization=organization,
email="[email protected]",
password=None,
level=OrganizationMembership.Level.OWNER,
)
organization.refresh_from_db()
assert len(organization.members.values_list("distinct_id", flat=True)) == 2 # one exists in the test base
BillingManager(license).update_billing_customer_email(organization)
assert billing_patch_request_mock.call_count == 1
assert billing_patch_request_mock.call_args[1]["json"]["org_customer_email"] == "[email protected]"
5 changes: 4 additions & 1 deletion posthog/api/test/test_signup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Dict, Optional, cast
from unittest import mock
from unittest.mock import ANY, patch
from zoneinfo import ZoneInfo

import pytest
from zoneinfo import ZoneInfo
from django.core import mail
from django.urls.base import reverse
from django.utils import timezone
Expand Down Expand Up @@ -543,6 +543,7 @@ def test_social_signup_with_allowed_domain_on_self_hosted(

@patch("posthoganalytics.capture")
@mock.patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids")
@mock.patch("ee.billing.billing_manager.BillingManager.update_billing_customer_email")
@mock.patch("social_core.backends.base.BaseAuth.request")
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
@mock.patch("posthog.tasks.user_identify.identify_task")
Expand All @@ -553,11 +554,13 @@ def test_social_signup_with_allowed_domain_on_cloud(
mock_sso_providers,
mock_request,
mock_update_distinct_ids,
mock_update_billing_customer_email,
mock_capture,
):
with self.is_cloud(True):
self.run_test_for_allowed_domain(mock_sso_providers, mock_request, mock_capture)
assert mock_update_distinct_ids.called_once()
assert mock_update_billing_customer_email.called_once()

@mock.patch("social_core.backends.base.BaseAuth.request")
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
Expand Down
19 changes: 9 additions & 10 deletions posthog/models/user.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from functools import cached_property
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypedDict,
)
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict

from django.contrib.auth.models import AbstractUser, BaseUserManager
from django.db import models, transaction
Expand Down Expand Up @@ -237,6 +228,8 @@ def join(
# We don't need to check for ExplicitTeamMembership as none can exist for a completely new member
self.current_team = organization.teams.order_by("id").filter(access_control=False).first()
self.save()
if level == OrganizationMembership.Level.OWNER and not self.current_organization.customer_id:
self.update_billing_customer_email(organization)
self.update_billing_distinct_ids(organization)
return membership

Expand Down Expand Up @@ -268,6 +261,12 @@ def update_billing_distinct_ids(self, organization: Organization) -> None:
if is_cloud() and get_cached_instance_license() is not None:
BillingManager(get_cached_instance_license()).update_billing_distinct_ids(organization)

def update_billing_customer_email(self, organization: Organization) -> None:
from ee.billing.billing_manager import BillingManager # avoid circular import

if is_cloud() and get_cached_instance_license() is not None:
BillingManager(get_cached_instance_license()).update_billing_customer_email(organization)

def get_analytics_metadata(self):
team_member_count_all: int = (
OrganizationMembership.objects.filter(organization__in=self.organizations.all())
Expand Down

0 comments on commit ecd5256

Please sign in to comment.