diff --git a/app_users/signals.py b/app_users/signals.py index 8e6b32dcf..bee42c391 100644 --- a/app_users/signals.py +++ b/app_users/signals.py @@ -1,19 +1,54 @@ -# from django.db import transaction -# from django.db.models.signals import post_delete -# from django.dispatch import receiver -# from firebase_admin import auth -# -# from app_users.models import AppUser -# -# -# @receiver(post_delete, sender=AppUser) -# def profile_post_delete(instance: AppUser, **kwargs): -# if not instance.uid: -# return -# -# @transaction.on_commit -# def _(): -# try: -# auth.delete_user(instance.uid) -# except auth.UserNotFoundError: -# pass +import stripe +from loguru import logger +from django.db.models.signals import post_save +from django.dispatch import receiver + +from app_users.models import AppUserTransaction, PaymentProvider, TransactionReason +from payments.plans import PricingPlan +from payments.webhooks import set_user_subscription + + +@receiver(post_save, sender=AppUserTransaction) +def after_stripe_addon(instance: AppUserTransaction, **kwargs): + if not ( + instance.payment_provider == PaymentProvider.STRIPE + and instance.reason == TransactionReason.ADDON + ): + return + + set_default_payment_method(instance) + set_free_subscription_on_user(instance) + + +def set_default_payment_method(instance: AppUserTransaction): + # update customer's defualt payment method + # note... that if a customer has an active subscription, the payment method attached there will be preferred + # see `stripe_get_default_payment_method` in payments/models.py module + invoice = stripe.Invoice.retrieve(instance.invoice_id, expand=["payment_intent"]) + if ( + invoice.payment_intent + and invoice.payment_intent.status == "succeeded" + and invoice.payment_intent.payment_method + ): + logger.info( + f"Updating default payment method for customer {invoice.customer} to {invoice.payment_intent.payment_method}" + ) + stripe.Customer.modify( + invoice.customer, + invoice_settings={ + "default_payment_method": invoice.payment_intent.payment_method + }, + ) + + +def set_free_subscription_on_user(instance: AppUserTransaction): + user = instance.user + if user.subscription: + return + + set_user_subscription( + provider=PaymentProvider.STRIPE, + plan=PricingPlan.STARTER, + uid=user.uid, + external_id=None, + ) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 09b48d375..ce8084b10 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -23,7 +23,7 @@ def billing_page(user: AppUser): render_payments_setup() - if user.subscription: + if user.subscription and user.subscription.plan != PricingPlan.STARTER.db_value: render_current_plan(user) with gui.div(className="my-5"): @@ -35,7 +35,7 @@ def billing_page(user: AppUser): with gui.div(className="my-5"): render_addon_section(user, selected_payment_provider) - if user.subscription and user.subscription.payment_provider: + if user.subscription: if user.subscription.payment_provider == PaymentProvider.STRIPE: with gui.div(className="my-5"): render_auto_recharge_section(user) @@ -131,7 +131,7 @@ def render_all_plans(user: AppUser) -> PaymentProvider: plans_div = gui.div(className="mb-1") if user.subscription and user.subscription.payment_provider: - selected_payment_provider = None + selected_payment_provider = user.subscription.payment_provider else: with gui.div(): selected_payment_provider = PaymentProvider[ @@ -149,7 +149,10 @@ def _render_plan(plan: PricingPlan): ): _render_plan_details(plan) _render_plan_action_button( - user, plan, current_plan, selected_payment_provider + user=user, + plan=plan, + current_plan=current_plan, + payment_provider=selected_payment_provider, ) with plans_div: @@ -198,7 +201,9 @@ def _render_plan_action_button( className=btn_classes + " btn btn-theme btn-primary", ): gui.html("Contact Us") - elif user.subscription and not user.subscription.payment_provider: + elif ( + user.subscription and user.subscription.plan == PricingPlan.ENTERPRISE.db_value + ): # don't show upgrade/downgrade buttons for enterprise customers # assumption: anyone without a payment provider attached is admin/enterprise return @@ -208,7 +213,7 @@ def _render_plan_action_button( else: label, btn_type = ("Downgrade", "secondary") - if user.subscription and user.subscription.payment_provider: + if user.subscription and user.subscription.external_id: # subscription exists, show upgrade/downgrade button _render_update_subscription_button( label, @@ -322,7 +327,6 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): if new_plan == PricingPlan.STARTER: user.subscription.cancel() - user.subscription.delete() raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) @@ -383,7 +387,7 @@ def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvid gui.write("# Purchase Credits") gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") - if user.subscription: + if user.subscription and user.subscription.payment_provider: provider = PaymentProvider(user.subscription.payment_provider) else: provider = selected_payment_provider @@ -423,7 +427,7 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser): "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}" ) if gui.button(f"${dollat_amt:,}", type="primary"): - if user.subscription: + if user.subscription and user.subscription.payment_provider: confirm_purchase_modal.open() else: stripe_addon_checkout_redirect(user, dollat_amt) @@ -503,7 +507,7 @@ def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan): from routers.account import account_route from routers.account import payment_processing_route - if user.subscription: + if user.subscription and user.subscription.plan == plan.db_value: # already subscribed to some plan return @@ -544,24 +548,28 @@ def render_paypal_subscription_button( def render_payment_information(user: AppUser): - assert user.subscription + if not user.subscription: + return + + pm_summary = gui.run_in_thread( + user.subscription.get_payment_method_summary, cache=True + ) + if not pm_summary: + return gui.write("## Payment Information", id="payment-information", className="d-block") col1, col2, col3 = gui.columns(3, responsive=False) with col1: gui.write("**Pay via**") with col2: - provider = PaymentProvider(user.subscription.payment_provider) + provider = PaymentProvider( + user.subscription.payment_provider or PaymentProvider.STRIPE + ) gui.write(provider.label) with col3: if gui.button(f"{icons.edit} Edit", type="link", key="manage-payment-provider"): raise gui.RedirectException(user.subscription.get_external_management_url()) - pm_summary = gui.run_in_thread( - user.subscription.get_payment_method_summary, cache=True - ) - if not pm_summary: - return pm_summary = PaymentMethodSummary(*pm_summary) if pm_summary.card_brand and pm_summary.card_last4: col1, col2, col3 = gui.columns(3, responsive=False) diff --git a/payments/models.py b/payments/models.py index a7deadc5e..daa397e8a 100644 --- a/payments/models.py +++ b/payments/models.py @@ -5,6 +5,7 @@ import stripe from django.db import models +from django.db.models import Q from django.utils import timezone from app_users.models import PaymentProvider @@ -72,10 +73,14 @@ class Subscription(models.Model): objects = SubscriptionQuerySet.as_manager() class Meta: - unique_together = ("payment_provider", "external_id") - indexes = [ - models.Index(fields=["plan"]), + constraints = [ + models.UniqueConstraint( + fields=["payment_provider", "external_id"], + condition=Q(plan__ne=PricingPlan.STARTER.db_value), + name="unique_provider_and_subscription_id", + ) ] + indexes = [models.Index(fields=["plan"])] def __str__(self): ret = f"{self.get_plan_display()} | {self.get_payment_provider_display()}" @@ -126,6 +131,9 @@ def has_user(self) -> bool: return True def cancel(self): + if not self.external_id: + return + match self.payment_provider: case PaymentProvider.STRIPE: stripe.Subscription.cancel(self.external_id) @@ -156,16 +164,6 @@ def get_next_invoice_timestamp(self) -> float | None: def get_payment_method_summary(self) -> PaymentMethodSummary | None: match self.payment_provider: - case PaymentProvider.STRIPE: - pm = self.stripe_get_default_payment_method() - if not pm: - return None - return PaymentMethodSummary( - payment_method_type=pm.type, - card_brand=pm.card and pm.card.brand, - card_last4=pm.card and pm.card.last4, - billing_email=(pm.billing_details and pm.billing_details.email), - ) case PaymentProvider.PAYPAL: subscription = paypal.Subscription.retrieve(self.external_id) subscriber = subscription.subscriber @@ -178,17 +176,32 @@ def get_payment_method_summary(self) -> PaymentMethodSummary | None: card_last4=source.get("card", {}).get("last_digits"), billing_email=subscriber.email_address, ) + case PaymentProvider.STRIPE | None: + # None is for the case when user doesn't have a subscription, but has their payment + # method on Stripe. we can use this to autopay for their addons or in autorecharge + pm = self.stripe_get_default_payment_method() + if not pm: + return None + return PaymentMethodSummary( + payment_method_type=pm.type, + card_brand=pm.card and pm.card.brand, + card_last4=pm.card and pm.card.last4, + billing_email=(pm.billing_details and pm.billing_details.email), + ) def stripe_get_default_payment_method(self) -> stripe.PaymentMethod | None: - if self.payment_provider != PaymentProvider.STRIPE: - raise ValueError("Invalid Payment Provider") - - subscription = stripe.Subscription.retrieve(self.external_id) - if subscription.default_payment_method: - return stripe.PaymentMethod.retrieve(subscription.default_payment_method) + if self.payment_provider == PaymentProvider.STRIPE and self.external_id: + subscription = stripe.Subscription.retrieve(self.external_id) + if subscription.default_payment_method: + return stripe.PaymentMethod.retrieve( + subscription.default_payment_method + ) - customer = stripe.Customer.retrieve(subscription.customer) - if customer.invoice_settings.default_payment_method: + customer = self.stripe_get_customer() + if ( + customer.invoice_settings + and customer.invoice_settings.default_payment_method + ): return stripe.PaymentMethod.retrieve( customer.invoice_settings.default_payment_method ) @@ -208,12 +221,16 @@ def stripe_get_or_create_auto_invoice( - Fetch a `metadata_key` invoice that was recently paid - Create an invoice with amount=`amount_in_dollars` and `metadata_key` set to true """ - customer_id = self.stripe_get_customer_id() + customer = self.stripe_get_customer() invoices = stripe.Invoice.list( - customer=customer_id, + customer=customer.id, collection_method="charge_automatically", ) - invoices = [inv for inv in invoices.data if metadata_key in inv.metadata] + invoices = [ + inv + for inv in invoices.data + if inv.metadata and metadata_key in inv.metadata + ] open_invoice = next((inv for inv in invoices if inv.status == "open"), None) if open_invoice: @@ -232,16 +249,16 @@ def stripe_get_or_create_auto_invoice( ) def stripe_create_auto_invoice(self, *, amount_in_dollars: int, metadata_key: str): - customer_id = self.stripe_get_customer_id() + customer = self.stripe_get_customer() invoice = stripe.Invoice.create( - customer=customer_id, + customer=customer.id, collection_method="charge_automatically", metadata={metadata_key: True}, auto_advance=False, pending_invoice_items_behavior="exclude", ) stripe.InvoiceItem.create( - customer=customer_id, + customer=customer.id, invoice=invoice, price_data={ "currency": "usd", @@ -257,11 +274,15 @@ def stripe_create_auto_invoice(self, *, amount_in_dollars: int, metadata_key: st invoice.finalize_invoice(auto_advance=True) return invoice - def stripe_get_customer_id(self) -> str: - if self.payment_provider == PaymentProvider.STRIPE: - subscription = stripe.Subscription.retrieve(self.external_id) + def stripe_get_customer(self) -> stripe.Customer: + if self.payment_provider == PaymentProvider.STRIPE and self.external_id: + subscription = stripe.Subscription.retrieve( + self.external_id, expand=["customer"] + ) return subscription.customer - raise ValueError("Invalid Payment Provider") + + assert self.has_user + return self.user.get_or_create_stripe_customer() def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool: from payments.webhooks import StripeWebhookHandler @@ -286,12 +307,6 @@ def get_external_management_url(self) -> str: from routers.account import account_route match self.payment_provider: - case PaymentProvider.STRIPE: - portal = stripe.billing_portal.Session.create( - customer=self.stripe_get_customer_id(), - return_url=get_app_route_url(account_route), - ) - return portal.url case PaymentProvider.PAYPAL: return str( settings.PAYPAL_WEB_BASE_URL @@ -300,6 +315,12 @@ def get_external_management_url(self) -> str: / "connect" / self.external_id ) + case PaymentProvider.STRIPE | None: + portal = stripe.billing_portal.Session.create( + customer=self.stripe_get_customer().id, + return_url=get_app_route_url(account_route), + ) + return portal.url case _: raise NotImplementedError( f"Can't get management URL for subscription with provider {self.payment_provider}" @@ -329,5 +350,9 @@ def should_send_monthly_spending_notification(self) -> bool: def nearest_choice(choices: list[int], value: int) -> int: - # nearest value in choices that is less than or equal to value - return min(filter(lambda x: x <= value, choices), key=lambda x: abs(x - value)) + # nearest choice that is less than or equal to the value (or the minimum choice if value is the least) + le_choices = [choice for choice in choices if choice <= value] + if not le_choices: + return min(choices) + else: + return min(le_choices, key=lambda x: abs(value - x)) diff --git a/payments/webhooks.py b/payments/webhooks.py index b30cae120..741499e16 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -59,7 +59,7 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): ) return - _set_user_subscription( + set_user_subscription( provider=cls.PROVIDER, plan=plan, uid=pp_sub.custom_id, @@ -109,20 +109,25 @@ def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): @classmethod def handle_checkout_session_completed(cls, uid: str, session_data): - if setup_intent_id := session_data.get("setup_intent") is None: + setup_intent_id = session_data.get("setup_intent") + if not setup_intent_id: # not a setup mode checkout -- do nothing return - setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) - - # subscription_id was passed to metadata when creating the session - sub_id = setup_intent.metadata["subscription_id"] - assert ( - sub_id - ), f"subscription_id is missing in setup_intent metadata {setup_intent}" - stripe.Subscription.modify( - sub_id, default_payment_method=setup_intent.payment_method - ) + setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) + if sub_id := setup_intent.metadata.get("subscription_id"): + # subscription_id was passed to metadata when creating the session + stripe.Subscription.modify( + sub_id, default_payment_method=setup_intent.payment_method + ) + elif customer_id := session_data.get("customer"): + # no subscription_id, so update the customer's default payment method instead + stripe.Customer.modify( + customer_id, + invoice_settings={ + "default_payment_method": setup_intent.payment_method + }, + ) @classmethod def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): @@ -146,7 +151,7 @@ def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): ) return - _set_user_subscription( + set_user_subscription( provider=cls.PROVIDER, plan=plan, uid=uid, @@ -190,22 +195,37 @@ def add_balance_for_payment( send_monthly_spending_notification_email.delay(user.id) -def _set_user_subscription( - *, provider: PaymentProvider, plan: PricingPlan, uid: str, external_id: str +def set_user_subscription( + *, + provider: PaymentProvider, + plan: PricingPlan, + uid: str, + external_id: str | None, ): + user = AppUser.objects.get_or_create_from_uid(uid)[0] + existing = user.subscription + if existing: + defaults = { + "auto_recharge_enabled": user.subscription.auto_recharge_enabled, + "auto_recharge_topup_amount": user.subscription.auto_recharge_topup_amount, + "auto_recharge_balance_threshold": user.subscription.auto_recharge_balance_threshold, + "monthly_spending_budget": user.subscription.monthly_spending_budget, + "monthly_spending_notification_threshold": user.subscription.monthly_spending_notification_threshold, + } + else: + defaults = {} + with transaction.atomic(): subscription, created = Subscription.objects.get_or_create( payment_provider=provider, external_id=external_id, - defaults=dict(plan=plan.db_value), + defaults=dict(plan=plan.db_value, **defaults), ) + subscription.plan = plan.db_value subscription.full_clean() subscription.save() - user = AppUser.objects.get_or_create_from_uid(uid)[0] - existing = user.subscription - user.subscription = subscription user.save(update_fields=["subscription"]) @@ -224,10 +244,13 @@ def _set_user_subscription( def _remove_subscription_for_user( *, uid: str, provider: PaymentProvider, external_id: str ): - AppUser.objects.filter( - uid=uid, - subscription__payment_provider=provider, - subscription__external_id=external_id, - ).update( - subscription=None, - ) + try: + user = AppUser.objects.get(uid=uid) + except AppUser.DoesNotExist: + logger.warning(f"User {uid} not found") + return + + user.subscription.plan = PricingPlan.STARTER.db_value + user.subscription.payment_provider = None + user.subscription.external_id = None + user.subscription.save(update_fields=["plan", "payment_provider", "external_id"])