Skip to content

Commit

Permalink
feat: set the user id in the billing token (#25441)
Browse files Browse the repository at this point in the history
  • Loading branch information
zlwaterfield authored Oct 10, 2024
1 parent 61d64b5 commit 1bb5016
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 18 deletions.
30 changes: 22 additions & 8 deletions ee/api/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError
from rest_framework.request import Request
from rest_framework.response import Response
from django.contrib.auth.models import AbstractUser

from ee.billing.billing_manager import BillingManager, build_billing_token
from ee.models import License
Expand Down Expand Up @@ -40,6 +41,13 @@ class BillingViewset(TeamAndOrgViewSetMixin, viewsets.GenericViewSet):

scope_object = "INTERNAL"

def get_billing_manager(self) -> BillingManager:
license = get_cached_instance_license()
user = (
self.request.user if isinstance(self.request.user, AbstractUser) and self.request.user.distinct_id else None
)
return BillingManager(license, user)

def list(self, request: Request, *args: Any, **kwargs: Any) -> Response:
license = get_cached_instance_license()
if license and not license.is_v2_license:
Expand All @@ -53,7 +61,8 @@ def list(self, request: Request, *args: Any, **kwargs: Any) -> Response:
raise NotFound("Billing V1 is active for this organization")

plan_keys = request.query_params.get("plan_keys", None)
response = BillingManager(license).get_billing(org, plan_keys)
billing_manager = self.get_billing_manager()
response = billing_manager.get_billing(org, plan_keys)

return Response(response)

Expand All @@ -68,7 +77,8 @@ def patch(self, request: Request, *args: Any, **kwargs: Any) -> Response:
if license and org: # for mypy
custom_limits_usd = request.data.get("custom_limits_usd")
if custom_limits_usd:
BillingManager(license).update_billing(org, {"custom_limits_usd": custom_limits_usd})
billing_manager = self.get_billing_manager()
billing_manager.update_billing(org, {"custom_limits_usd": custom_limits_usd})

if distinct_id:
posthoganalytics.capture(
Expand Down Expand Up @@ -153,7 +163,6 @@ class DeactivateSerializer(serializers.Serializer):

@action(methods=["GET"], detail=False)
def deactivate(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
license = get_cached_instance_license()
organization = self._get_org_required()

serializer = self.DeactivateSerializer(data=request.GET)
Expand All @@ -162,7 +171,8 @@ def deactivate(self, request: Request, *args: Any, **kwargs: Any) -> HttpRespons
products = serializer.validated_data.get("products")

try:
BillingManager(license).deactivate_products(organization, products)
billing_manager = self.get_billing_manager()
billing_manager.deactivate_products(organization, products)
except Exception as e:
if len(e.args) > 2:
detail_object = e.args[2]
Expand Down Expand Up @@ -191,7 +201,8 @@ def portal(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:

organization = self._get_org_required()

res = BillingManager(license)._get_stripe_portal_url(organization)
billing_manager = self.get_billing_manager()
res = billing_manager._get_stripe_portal_url(organization)
return redirect(res)

@action(methods=["GET"], detail=False)
Expand All @@ -208,7 +219,8 @@ def get_invoices(self, request: Request, *args: Any, **kwargs: Any) -> HttpRespo
invoice_status = request.GET.get("status")

try:
res = BillingManager(license).get_invoices(organization, status=invoice_status)
billing_manager = self.get_billing_manager()
res = billing_manager.get_invoices(organization, status=invoice_status)
except Exception as e:
if len(e.args) > 2:
detail_object = e.args[2]
Expand Down Expand Up @@ -244,7 +256,8 @@ def credits_overview(self, request: Request, *args: Any, **kwargs: Any) -> HttpR

organization = self._get_org_required()

res = BillingManager(license).credits_overview(organization)
billing_manager = self.get_billing_manager()
res = billing_manager.credits_overview(organization)
return Response(res, status=status.HTTP_200_OK)

@action(methods=["POST"], detail=False, url_path="credits/purchase")
Expand All @@ -258,7 +271,8 @@ def purchase_credits(self, request: Request, *args: Any, **kwargs: Any) -> HttpR

organization = self._get_org_required()

res = BillingManager(license).purchase_credits(organization, request.data)
billing_manager = self.get_billing_manager()
res = billing_manager.purchase_credits(organization, request.data)
return Response(res, status=status.HTTP_200_OK)

@action(methods=["PATCH"], detail=False)
Expand Down
1 change: 1 addition & 0 deletions ee/api/test/test_billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def mock_implementation(url: str, headers: Any = None, params: Any = None) -> Ma

assert decoded_token == {
"aud": "posthog:license-key",
"distinct_id": str(self.user.distinct_id),
"exp": 1640996100,
"id": self.license.key.split("::")[0],
"organization_id": str(self.organization.id),
Expand Down
28 changes: 18 additions & 10 deletions ee/billing/billing_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from posthog.cloud_utils import get_cached_instance_license
from posthog.models import Organization
from posthog.models.organization import OrganizationMembership, OrganizationUsageInfo
from posthog.models.user import User

logger = structlog.get_logger(__name__)

Expand All @@ -28,21 +29,26 @@ class BillingAPIErrorCodes(Enum):
OPEN_INVOICES_ERROR = "open_invoices_error"


def build_billing_token(license: License, organization: Organization):
def build_billing_token(license: License, organization: Organization, user: Optional[User] = None):
if not organization or not license:
raise NotAuthenticated()

license_id = license.key.split("::")[0]
license_secret = license.key.split("::")[1]

payload = {
"exp": datetime.now(tz=timezone.utc) + timedelta(minutes=15),
"id": license_id,
"organization_id": str(organization.id),
"organization_name": organization.name,
"aud": "posthog:license-key",
}

if user:
payload["distinct_id"] = str(user.distinct_id)

encoded_jwt = jwt.encode(
{
"exp": datetime.now(tz=timezone.utc) + timedelta(minutes=15),
"id": license_id,
"organization_id": str(organization.id),
"organization_name": organization.name,
"aud": "posthog:license-key",
},
payload,
license_secret,
algorithm="HS256",
)
Expand All @@ -62,9 +68,11 @@ def handle_billing_service_error(res: requests.Response, valid_codes=(200, 404,

class BillingManager:
license: Optional[License]
user: Optional[User]

def __init__(self, license):
def __init__(self, license, user: Optional[User] = None):
self.license = license or get_cached_instance_license()
self.user = user

def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> dict[str, Any]:
if organization and self.license and self.license.is_v2_license:
Expand Down Expand Up @@ -331,7 +339,7 @@ def update_org_details(self, organization: Organization, billing_status: Billing
def get_auth_headers(self, organization: Organization):
if not self.license: # mypy
raise Exception("No license found")
billing_service_token = build_billing_token(self.license, organization)
billing_service_token = build_billing_token(self.license, organization, self.user)
return {"Authorization": f"Bearer {billing_service_token}"}

def get_invoices(self, organization: Organization, status: Optional[str]):
Expand Down

0 comments on commit 1bb5016

Please sign in to comment.