diff --git a/app_users/apps.py b/app_users/apps.py index 023d4ae99..e8d750128 100644 --- a/app_users/apps.py +++ b/app_users/apps.py @@ -5,8 +5,3 @@ class AppUsersConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" name = "app_users" verbose_name = "App Users" - - def ready(self): - from . import signals - - assert signals diff --git a/app_users/signals.py b/app_users/signals.py deleted file mode 100644 index 8e6b32dcf..000000000 --- a/app_users/signals.py +++ /dev/null @@ -1,19 +0,0 @@ -# from django.db import transaction -# from django.db.models.signals import post_delete -# from django.dispatch import receiver -# from firebase_admin import auth -# -# from app_users.models import AppUser -# -# -# @receiver(post_delete, sender=AppUser) -# def profile_post_delete(instance: AppUser, **kwargs): -# if not instance.uid: -# return -# -# @transaction.on_commit -# def _(): -# try: -# auth.delete_user(instance.uid) -# except auth.UserNotFoundError: -# pass diff --git a/app_users/tasks.py b/app_users/tasks.py new file mode 100644 index 000000000..0327ac423 --- /dev/null +++ b/app_users/tasks.py @@ -0,0 +1,54 @@ +import stripe +from loguru import logger + +from app_users.models import PaymentProvider, TransactionReason +from celeryapp.celeryconfig import app +from payments.models import Subscription +from payments.plans import PricingPlan +from payments.webhooks import set_user_subscription + + +@app.task +def save_stripe_default_payment_method( + *, + payment_intent_id: str, + uid: str, + amount: int, + charged_amount: int, + reason: TransactionReason, +): + pi = stripe.PaymentIntent.retrieve(payment_intent_id, expand=["payment_method"]) + pm = pi.payment_method + if not (pm and pm.customer): + logger.error( + f"Failed to retrieve payment method for payment intent {payment_intent_id}" + ) + return + + # update customer's defualt payment method + # note: if a customer has an active subscription, the payment method attached there will be preferred + # see `stripe_get_default_payment_method` in payments/models.py module + logger.info( + f"Updating default payment method for customer {pm.customer} to {pm.id}" + ) + stripe.Customer.modify( + pm.customer, + invoice_settings=dict(default_payment_method=pm), + ) + + # if user doesn't already have a active billing/autorecharge info, so we don't need to do anything + # set user's subscription to the free plan + if ( + reason == TransactionReason.ADDON + and not Subscription.objects.filter( + user__uid=uid, payment_provider__isnull=False + ).exists() + ): + set_user_subscription( + uid=uid, + plan=PricingPlan.STARTER, + provider=PaymentProvider.STRIPE, + external_id=None, + amount=amount, + charged_amount=charged_amount, + ) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index fbbc7e60b..9e1c679fb 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -52,6 +52,7 @@ from daras_ai_v2.exceptions import InsufficientCredits from daras_ai_v2.fastapi_tricks import get_route_path from daras_ai_v2.grid_layout_widget import grid_layout +from daras_ai_v2.gui_confirm import confirm_modal from daras_ai_v2.html_spinner_widget import html_spinner from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_preview_url import meta_preview_url @@ -697,10 +698,29 @@ def _render_options_modal( save_as_new_button = gui.button( f"{save_as_new_icon} Save as New", className="w-100" ) - delete_button = not published_run.is_root() and gui.button( - f' Delete', - className="w-100 text-danger", - ) + + if not published_run.is_root(): + confirm_delete_modal, confirmed = confirm_modal( + title="Are you sure?", + key="--delete-run-modal", + text=f""" +Are you sure you want to delete this published run? + +**{published_run.title}** + +This will also delete all the associated versions. + """, + button_label="Delete", + button_class="border-danger bg-danger text-white", + ) + if gui.button( + f' Delete', + className="w-100 text-danger", + ): + confirm_delete_modal.open() + if confirmed: + published_run.delete() + raise gui.RedirectException(self.app_url()) if duplicate_button: duplicate_pr = self.duplicate_published_run( @@ -730,47 +750,6 @@ def _render_options_modal( gui.write("#### Version History", className="mb-4") self._render_version_history() - confirm_delete_modal = gui.Modal("Confirm Delete", key="confirm-delete-modal") - if delete_button: - confirm_delete_modal.open() - if confirm_delete_modal.is_open(): - modal.empty() - with confirm_delete_modal.container(): - self._render_confirm_delete_modal( - published_run=published_run, - modal=confirm_delete_modal, - ) - - def _render_confirm_delete_modal( - self, - *, - published_run: PublishedRun, - modal: gui.Modal, - ): - gui.write( - "Are you sure you want to delete this published run? " - f"_({published_run.title})_" - ) - gui.caption("This will also delete all the associated versions.") - with gui.div(className="d-flex"): - confirm_button = gui.button( - 'Confirm', - type="secondary", - className="w-100", - ) - cancel_button = gui.button( - "Cancel", - type="secondary", - className="w-100", - ) - - if confirm_button: - published_run.delete() - raise gui.RedirectException(self.app_url()) - - if cancel_button: - modal.close() - def _render_admin_options(self, current_run: SavedRun, published_run: PublishedRun): if ( not self.is_current_user_admin() diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 09b48d375..25e607415 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -1,5 +1,3 @@ -from typing import Literal - import gooey_gui as gui import stripe from django.core.exceptions import ValidationError @@ -8,22 +6,21 @@ from daras_ai_v2 import icons, settings, paypal from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.grid_layout_widget import grid_layout +from daras_ai_v2.gui_confirm import confirm_modal from daras_ai_v2.settings import templates from daras_ai_v2.user_date_widgets import render_local_date_attrs from payments.models import PaymentMethodSummary from payments.plans import PricingPlan +from payments.webhooks import StripeWebhookHandler, set_user_subscription from scripts.migrate_existing_subscriptions import available_subscriptions rounded_border = "w-100 border shadow-sm rounded py-4 px-3" -PlanActionLabel = Literal["Upgrade", "Downgrade", "Contact Us", "Your Plan"] - - def billing_page(user: AppUser): render_payments_setup() - if user.subscription: + if user.subscription and user.subscription.is_paid(): render_current_plan(user) with gui.div(className="my-5"): @@ -35,10 +32,11 @@ def billing_page(user: AppUser): with gui.div(className="my-5"): render_addon_section(user, selected_payment_provider) - if user.subscription and user.subscription.payment_provider: + if user.subscription: if user.subscription.payment_provider == PaymentProvider.STRIPE: with gui.div(className="my-5"): render_auto_recharge_section(user) + with gui.div(className="my-5"): render_payment_information(user) @@ -59,11 +57,10 @@ def render_payments_setup(): def render_current_plan(user: AppUser): plan = PricingPlan.from_sub(user.subscription) - provider = ( - PaymentProvider(user.subscription.payment_provider) - if user.subscription.payment_provider - else None - ) + if user.subscription.payment_provider: + provider = PaymentProvider(user.subscription.payment_provider) + else: + provider = None with gui.div(className=f"{rounded_border} border-dark"): # ROW 1: Plan title and next invoice date @@ -101,7 +98,10 @@ def render_current_plan(user: AppUser): with left: gui.write(f"# {plan.pricing_title()}", className="no-margin") if plan.monthly_charge: - provider_text = f" **via {provider.label}**" if provider else "" + if provider: + provider_text = f" **via {provider.label}**" + else: + provider_text = "" gui.caption("per month" + provider_text) with right, gui.div(className="text-end"): @@ -131,7 +131,7 @@ def render_all_plans(user: AppUser) -> PaymentProvider: plans_div = gui.div(className="mb-1") if user.subscription and user.subscription.payment_provider: - selected_payment_provider = None + selected_payment_provider = user.subscription.payment_provider else: with gui.div(): selected_payment_provider = PaymentProvider[ @@ -149,7 +149,10 @@ def _render_plan(plan: PricingPlan): ): _render_plan_details(plan) _render_plan_action_button( - user, plan, current_plan, selected_payment_provider + user=user, + plan=plan, + current_plan=current_plan, + payment_provider=selected_payment_provider, ) with plans_div: @@ -198,30 +201,59 @@ def _render_plan_action_button( className=btn_classes + " btn btn-theme btn-primary", ): gui.html("Contact Us") - elif user.subscription and not user.subscription.payment_provider: + elif ( + user.subscription and user.subscription.plan == PricingPlan.ENTERPRISE.db_value + ): # don't show upgrade/downgrade buttons for enterprise customers - # assumption: anyone without a payment provider attached is admin/enterprise return else: - if plan.credits > current_plan.credits: - label, btn_type = ("Upgrade", "primary") - else: - label, btn_type = ("Downgrade", "secondary") - - if user.subscription and user.subscription.payment_provider: + if user.subscription and user.subscription.is_paid(): # subscription exists, show upgrade/downgrade button - _render_update_subscription_button( - label, - user=user, - current_plan=current_plan, - plan=plan, - className=f"{btn_classes} btn btn-theme btn-{btn_type}", - ) + if plan.credits > current_plan.credits: + modal, confirmed = confirm_modal( + title="Upgrade Plan", + key=f"--modal-{plan.key}", + text=f""" +Are you sure you want to upgrade from: **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**? + +This will charge you the full amount today, and every month thereafter. + +**{current_plan.credits:,} credits** will be added to your account. + """, + button_label="Buy", + ) + if gui.button( + "Upgrade", className="primary", key=f"--change-sub-{plan.key}" + ): + modal.open() + if confirmed: + change_subscription( + user, + plan, + # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time + billing_cycle_anchor="now", + ) + else: + modal, confirmed = confirm_modal( + title="Downgrade Plan", + key=f"--modal-{plan.key}", + text=f""" +Are you sure you want to downgrade from: **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**? + +This will take effect from the next billing cycle. + """, + button_label="Downgrade", + button_class="border-danger bg-danger text-white", + ) + if gui.button( + "Downgrade", className="secondary", key=f"--change-sub-{plan.key}" + ): + modal.open() + if confirmed: + change_subscription(user, plan) else: assert payment_provider is not None # for sanity _render_create_subscription_button( - label, - btn_type=btn_type, user=user, plan=plan, payment_provider=payment_provider, @@ -229,81 +261,18 @@ def _render_plan_action_button( def _render_create_subscription_button( - label: PlanActionLabel, *, - btn_type: str, user: AppUser, plan: PricingPlan, payment_provider: PaymentProvider, ): match payment_provider: case PaymentProvider.STRIPE: - key = f"stripe-sub-{plan.key}" - render_stripe_subscription_button( - user=user, - label=label, - plan=plan, - btn_type=btn_type, - key=key, - ) + render_stripe_subscription_button(user=user, plan=plan) case PaymentProvider.PAYPAL: render_paypal_subscription_button(plan=plan) -def _render_update_subscription_button( - label: PlanActionLabel, - *, - user: AppUser, - current_plan: PricingPlan, - plan: PricingPlan, - className: str = "", -): - key = f"change-sub-{plan.key}" - match label: - case "Downgrade": - downgrade_modal = gui.Modal( - "Confirm downgrade", - key=f"downgrade-plan-modal-{plan.key}", - ) - if gui.button( - label, - className=className, - key=key, - ): - downgrade_modal.open() - - if downgrade_modal.is_open(): - with downgrade_modal.container(): - gui.write( - f""" - Are you sure you want to change from: - **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**? - """, - className="d-block py-4", - ) - with gui.div(className="d-flex w-100"): - if gui.button( - "Downgrade", - className="btn btn-theme bg-danger border-danger text-white", - key=f"{key}-confirm", - ): - change_subscription(user, plan) - if gui.button( - "Cancel", - className="border border-danger text-danger", - key=f"{key}-cancel", - ): - downgrade_modal.close() - case _: - if gui.button(label, className=className, key=key): - change_subscription( - user, - plan, - # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time - billing_cycle_anchor="now", - ) - - def fmt_price(plan: PricingPlan) -> str: if plan.monthly_charge: return f"${plan.monthly_charge:,}/month" @@ -322,7 +291,6 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): if new_plan == PricingPlan.STARTER: user.subscription.cancel() - user.subscription.delete() raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) @@ -383,7 +351,7 @@ def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvid gui.write("# Purchase Credits") gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") - if user.subscription: + if user.subscription and user.subscription.payment_provider: provider = PaymentProvider(user.subscription.payment_provider) else: provider = selected_payment_provider @@ -414,57 +382,66 @@ def render_paypal_addon_buttons(): def render_stripe_addon_buttons(user: AppUser): + if not (user.subscription and user.subscription.payment_provider): + save_pm = gui.checkbox( + "Save payment method for future purchases & auto-recharge", value=True + ) + else: + save_pm = True + for dollat_amt in settings.ADDON_AMOUNT_CHOICES: - render_stripe_addon_button(dollat_amt, user) + render_stripe_addon_button(dollat_amt, user, save_pm) + error = gui.session_state.pop("--addon-purchase-error", None) + if error: + gui.error(error) -def render_stripe_addon_button(dollat_amt: int, user: AppUser): - confirm_purchase_modal = gui.Modal( - "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}" + +def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool): + modal, confirmed = confirm_modal( + title="Purchase Credits", + key=f"--addon-modal-{dollat_amt}", + text=f""" +Please confirm your purchase: **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**. + +This is a one-time purchase. Your account will be credited immediately. + """, + button_label="Buy", + text_on_confirm="Processing Payment...", ) + if gui.button(f"${dollat_amt:,}", type="primary"): - if user.subscription: - confirm_purchase_modal.open() + if user.subscription and user.subscription.stripe_get_default_payment_method(): + modal.open() else: - stripe_addon_checkout_redirect(user, dollat_amt) + stripe_addon_checkout_redirect(user, dollat_amt, save_pm) - if not confirm_purchase_modal.is_open(): - return - with confirm_purchase_modal.container(): - gui.write( - f""" - Please confirm your purchase: - **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**. - """, - className="py-4 d-block text-center", + if confirmed: + success = gui.run_in_thread( + user.subscription.stripe_attempt_addon_purchase, + args=[dollat_amt], + placeholder="", ) - with gui.div(className="d-flex w-100 justify-content-end"): - if gui.session_state.get("--confirm-purchase"): - success = gui.run_in_thread( - user.subscription.stripe_attempt_addon_purchase, - args=[dollat_amt], - placeholder="Processing payment...", - ) - if success is None: - return - gui.session_state.pop("--confirm-purchase") - if success: - confirm_purchase_modal.close() - else: - gui.error("Payment failed... Please try again.") - return - - if gui.button("Cancel", className="border border-danger text-danger me-2"): - confirm_purchase_modal.close() - gui.button("Buy", type="primary", key="--confirm-purchase") + if success is None: + return + if not success: + gui.session_state["--addon-purchase-error"] = ( + "Payment failed... Please try again or contact us at support@gooey.ai" + ) + modal.close() -def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): +def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool): from routers.account import account_route from routers.account import payment_processing_route line_item = available_subscriptions["addon"]["stripe"].copy() line_item["quantity"] = dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR + kwargs = {} + if save_pm: + kwargs["payment_intent_data"] = {"setup_future_usage": "on_session"} + else: + kwargs["saved_payment_method_options"] = {"payment_method_save": "enabled"} checkout_session = stripe.checkout.Session.create( line_items=[line_item], mode="payment", @@ -473,55 +450,98 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): customer=user.get_or_create_stripe_customer(), invoice_creation={"enabled": True}, allow_promotion_codes=True, - saved_payment_method_options={ - "payment_method_save": "enabled", - }, + **kwargs, ) raise gui.RedirectException(checkout_session.url, status_code=303) def render_stripe_subscription_button( *, - label: str, user: AppUser, plan: PricingPlan, - btn_type: str, - key: str, ): if not plan.supports_stripe(): gui.write("Stripe subscription not available") return + modal, confirmed = confirm_modal( + title="Upgrade Plan", + key=f"--modal-{plan.key}", + text=f""" +Are you sure you want to subscribe to **{plan.title} ({fmt_price(plan)})**? + +This will charge you the full amount today, and every month thereafter. + +**{plan.credits:,} credits** will be added to your account. + """, + button_label="Buy", + ) + # IMPORTANT: key=... is needed here to maintain uniqueness # of buttons with the same label. otherwise, all buttons # will be the same to the server - if gui.button(label, key=key, type=btn_type): - stripe_subscription_checkout_redirect(user=user, plan=plan) + if gui.button( + "Upgrade", + key=f"--change-sub-{plan.key}", + type="primary", + ): + if user.subscription and user.subscription.stripe_get_default_payment_method(): + modal.open() + else: + stripe_subscription_create(user=user, plan=plan) + if confirmed: + stripe_subscription_create(user=user, plan=plan) -def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan): + +def stripe_subscription_create(user: AppUser, plan: PricingPlan): from routers.account import account_route from routers.account import payment_processing_route - if user.subscription: + if user.subscription and user.subscription.plan == plan.db_value: # already subscribed to some plan return + pm = user.subscription and user.subscription.stripe_get_default_payment_method() metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key} - checkout_session = stripe.checkout.Session.create( - line_items=[(plan.get_stripe_line_item())], - mode="subscription", - success_url=get_app_route_url(payment_processing_route), - cancel_url=get_app_route_url(account_route), - customer=user.get_or_create_stripe_customer(), - metadata=metadata, - subscription_data={"metadata": metadata}, - allow_promotion_codes=True, - saved_payment_method_options={ - "payment_method_save": "enabled", - }, - ) - raise gui.RedirectException(checkout_session.url, status_code=303) + line_items = [plan.get_stripe_line_item()] + if pm: + # directly create the subscription without checkout + stripe.Subscription.create( + customer=pm.customer, + items=line_items, + metadata=metadata, + default_payment_method=pm.id, + proration_behavior="none", + ) + raise gui.RedirectException( + get_app_route_url(payment_processing_route), status_code=303 + ) + else: + # check for existing subscriptions + customer = user.get_or_create_stripe_customer() + for sub in stripe.Subscription.list( + customer=customer, status="active", limit=1 + ).data: + StripeWebhookHandler.handle_subscription_updated( + uid=user.uid, stripe_sub=sub + ) + raise gui.RedirectException( + get_app_route_url(payment_processing_route), status_code=303 + ) + + checkout_session = stripe.checkout.Session.create( + mode="subscription", + success_url=get_app_route_url(payment_processing_route), + cancel_url=get_app_route_url(account_route), + allow_promotion_codes=True, + customer=customer, + line_items=line_items, + metadata=metadata, + subscription_data={"metadata": metadata}, + saved_payment_method_options={"payment_method_save": "enabled"}, + ) + raise gui.RedirectException(checkout_session.url, status_code=303) def render_paypal_subscription_button( @@ -544,44 +564,90 @@ def render_paypal_subscription_button( def render_payment_information(user: AppUser): - assert user.subscription - - gui.write("## Payment Information", id="payment-information", className="d-block") - col1, col2, col3 = gui.columns(3, responsive=False) - with col1: - gui.write("**Pay via**") - with col2: - provider = PaymentProvider(user.subscription.payment_provider) - gui.write(provider.label) - with col3: - if gui.button(f"{icons.edit} Edit", type="link", key="manage-payment-provider"): - raise gui.RedirectException(user.subscription.get_external_management_url()) + if not user.subscription: + return pm_summary = gui.run_in_thread( user.subscription.get_payment_method_summary, cache=True ) if not pm_summary: return - pm_summary = PaymentMethodSummary(*pm_summary) - if pm_summary.card_brand and pm_summary.card_last4: + + gui.write("## Payment Information", id="payment-information", className="d-block") + with gui.div(className="ps-1"): col1, col2, col3 = gui.columns(3, responsive=False) with col1: - gui.write("**Payment Method**") + gui.write("**Pay via**") with col2: - gui.write( - f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}", - unsafe_allow_html=True, + provider = PaymentProvider( + user.subscription.payment_provider or PaymentProvider.STRIPE ) + gui.write(provider.label) with col3: - if gui.button(f"{icons.edit} Edit", type="link", key="edit-payment-method"): - change_payment_method(user) + if gui.button( + f"{icons.edit} Edit", type="link", key="manage-payment-provider" + ): + raise gui.RedirectException( + user.subscription.get_external_management_url() + ) - if pm_summary.billing_email: - col1, col2, _ = gui.columns(3, responsive=False) - with col1: - gui.write("**Billing Email**") - with col2: - gui.html(pm_summary.billing_email) + pm_summary = PaymentMethodSummary(*pm_summary) + if pm_summary.card_brand: + col1, col2, col3 = gui.columns(3, responsive=False) + with col1: + gui.write("**Payment Method**") + with col2: + if pm_summary.card_last4: + gui.write( + f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}", + unsafe_allow_html=True, + ) + else: + gui.write(pm_summary.card_brand) + with col3: + if gui.button( + f"{icons.edit} Edit", type="link", key="edit-payment-method" + ): + change_payment_method(user) + + if pm_summary.billing_email: + col1, col2, _ = gui.columns(3, responsive=False) + with col1: + gui.write("**Billing Email**") + with col2: + gui.html(pm_summary.billing_email) + + from routers.account import payment_processing_route + + modal, confirmed = confirm_modal( + title="Delete Payment Information", + key="--delete-payment-method", + text=""" +Are you sure you want to delete your payment information? + +This will cancel your subscription and remove your saved payment method. + """, + button_label="Delete", + button_class="border-danger bg-danger text-white", + ) + if gui.button( + "Delete & Cancel Subscription", + className="border-danger text-danger", + ): + modal.open() + if confirmed: + set_user_subscription( + uid=user.uid, + plan=PricingPlan.STARTER, + provider=None, + external_id=None, + ) + pm = user.subscription and user.subscription.stripe_get_default_payment_method() + if pm: + pm.detach() + raise gui.RedirectException( + get_app_route_url(payment_processing_route), status_code=303 + ) def change_payment_method(user: AppUser): @@ -636,10 +702,7 @@ def render_billing_history(user: AppUser, limit: int = 50): def render_auto_recharge_section(user: AppUser): - assert ( - user.subscription - and user.subscription.payment_provider == PaymentProvider.STRIPE - ) + assert user.subscription subscription = user.subscription gui.write("## Auto Recharge & Limits") diff --git a/daras_ai_v2/gui_confirm.py b/daras_ai_v2/gui_confirm.py new file mode 100644 index 000000000..dd719d7a5 --- /dev/null +++ b/daras_ai_v2/gui_confirm.py @@ -0,0 +1,42 @@ +import gooey_gui as gui + +from daras_ai_v2.html_spinner_widget import html_spinner + + +def confirm_modal( + *, + title: str, + key: str, + text: str, + button_label: str, + button_class: str = "", + text_on_confirm: str | None = None, +) -> tuple[gui.Modal, bool]: + modal = gui.Modal(title, key=key) + confirmed_key = f"{key}-confirmed" + if modal.is_open(): + with modal.container(): + with gui.div(className="pt-4 pb-3"): + gui.write(text) + with gui.div(className="d-flex w-100 justify-content-end"): + confirmed = bool(gui.session_state.get(confirmed_key, None)) + if confirmed and text_on_confirm: + html_spinner(text_on_confirm) + else: + if gui.button( + "Cancel", + type="tertiary", + className="me-2", + key=f"{key}-cancelled", + ): + modal.close() + confirmed = gui.button( + button_label, + type="primary", + key=confirmed_key, + className=button_class, + ) + return modal, confirmed + else: + gui.session_state.pop(confirmed_key, None) + return modal, False diff --git a/payments/admin.py b/payments/admin.py index c889bc109..9fd8f910f 100644 --- a/payments/admin.py +++ b/payments/admin.py @@ -5,4 +5,14 @@ @admin.register(Subscription) class SubscriptionAdmin(admin.ModelAdmin): - search_fields = ["plan", "payment_provider", "external_id"] + search_fields = [ + "plan", + "payment_provider", + "external_id", + ] + readonly_fields = [ + "user", + "created_at", + "updated_at", + "get_payment_method_summary", + ] diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py index 1aa414cfb..14d6ba49d 100644 --- a/payments/auto_recharge.py +++ b/payments/auto_recharge.py @@ -107,6 +107,9 @@ def _auto_recharge_user(user: AppUser): # get default payment method and attempt payment assert invoice.status == "open" # sanity check pm = user.subscription.stripe_get_default_payment_method() + if not pm: + logger.warning(f"{user} has no default payment method, cannot auto-recharge") + return try: invoice_data = invoice.pay(payment_method=pm) diff --git a/payments/models.py b/payments/models.py index a7deadc5e..5a200ba1e 100644 --- a/payments/models.py +++ b/payments/models.py @@ -5,6 +5,7 @@ import stripe from django.db import models +from django.db.models import Q from django.utils import timezone from app_users.models import PaymentProvider @@ -72,7 +73,13 @@ class Subscription(models.Model): objects = SubscriptionQuerySet.as_manager() class Meta: - unique_together = ("payment_provider", "external_id") + constraints = [ + models.UniqueConstraint( + fields=["payment_provider", "external_id"], + condition=Q(plan__monthly_charge__gt=0), + name="unique_provider_and_subscription_id", + ) + ] indexes = [ models.Index(fields=["plan"]), ] @@ -85,37 +92,39 @@ def __str__(self): ret = f"Auto | {ret}" return ret - def full_clean(self, *args, **kwargs): + def full_clean( + self, amount: int = None, charged_amount: int = None, *args, **kwargs + ): + if self.auto_recharge_enabled: + if amount is None: + amount = PricingPlan.from_sub(self).credits + if charged_amount is None: + charged_amount = PricingPlan.from_sub(self).monthly_charge * 100 + self.ensure_default_auto_recharge_params( + amount=amount, charged_amount=charged_amount + ) + return super().full_clean(*args, **kwargs) + + def ensure_default_auto_recharge_params(self, *, amount: int, charged_amount: int): + if amount <= 0 or charged_amount <= 0: + return + if not self.auto_recharge_balance_threshold: - self.auto_recharge_balance_threshold = ( - self._get_default_auto_recharge_balance_threshold() + # 25% of the credits + self.auto_recharge_balance_threshold = nearest_choice( + settings.AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES, 0.25 * amount ) if not self.monthly_spending_budget: - self.monthly_spending_budget = self._get_default_monthly_spending_budget() + # 3x the charged amount + self.monthly_spending_budget = 3 * charged_amount / 100 # in dollars if not self.monthly_spending_notification_threshold: - self.monthly_spending_notification_threshold = ( - self._get_default_monthly_spending_notification_threshold() + # 80% of the monthly budget + self.monthly_spending_notification_threshold = int( + 0.8 * self.monthly_spending_budget ) - return super().full_clean(*args, **kwargs) - - def _get_default_auto_recharge_balance_threshold(self): - # 25% of the monthly credit subscription - threshold = int(PricingPlan.from_sub(self).credits * 0.25) - return nearest_choice( - settings.AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES, threshold - ) - - def _get_default_monthly_spending_budget(self): - # 3x the monthly subscription charge - return 3 * PricingPlan.from_sub(self).monthly_charge - - def _get_default_monthly_spending_notification_threshold(self): - # 80% of the monthly budget - return int(0.8 * self._get_default_monthly_spending_budget()) - @property def has_user(self) -> bool: try: @@ -125,10 +134,26 @@ def has_user(self) -> bool: else: return True + def is_paid(self) -> bool: + return PricingPlan.from_sub(self).monthly_charge > 0 and self.external_id + def cancel(self): + from payments.webhooks import StripeWebhookHandler + + if not self.is_paid(): + return + match self.payment_provider: case PaymentProvider.STRIPE: - stripe.Subscription.cancel(self.external_id) + try: + stripe.Subscription.cancel(self.external_id) + except stripe.error.InvalidRequestError as e: + if e.code == "resource_missing": + StripeWebhookHandler.handle_subscription_cancelled( + self.user.uid + ) + else: + raise case PaymentProvider.PAYPAL: paypal.Subscription.retrieve(self.external_id).cancel() case _: @@ -156,16 +181,6 @@ def get_next_invoice_timestamp(self) -> float | None: def get_payment_method_summary(self) -> PaymentMethodSummary | None: match self.payment_provider: - case PaymentProvider.STRIPE: - pm = self.stripe_get_default_payment_method() - if not pm: - return None - return PaymentMethodSummary( - payment_method_type=pm.type, - card_brand=pm.card and pm.card.brand, - card_last4=pm.card and pm.card.last4, - billing_email=(pm.billing_details and pm.billing_details.email), - ) case PaymentProvider.PAYPAL: subscription = paypal.Subscription.retrieve(self.external_id) subscriber = subscription.subscriber @@ -178,20 +193,43 @@ def get_payment_method_summary(self) -> PaymentMethodSummary | None: card_last4=source.get("card", {}).get("last_digits"), billing_email=subscriber.email_address, ) + case PaymentProvider.STRIPE: + pm = self.stripe_get_default_payment_method() + if not pm: + # clear the payment provider if the default payment method is missing + if self.payment_provider and not self.is_paid(): + self.payment_provider = None + self.save(update_fields=["payment_provider"]) + return None + return PaymentMethodSummary( + payment_method_type=pm.type, + card_brand=( + (pm.type == "card" and pm.card and pm.card.brand) or pm.type + ), + card_last4=(pm.type == "card" and pm.card and pm.card.last4) or "", + billing_email=(pm.billing_details and pm.billing_details.email), + ) def stripe_get_default_payment_method(self) -> stripe.PaymentMethod | None: if self.payment_provider != PaymentProvider.STRIPE: - raise ValueError("Invalid Payment Provider") - - subscription = stripe.Subscription.retrieve(self.external_id) - if subscription.default_payment_method: - return stripe.PaymentMethod.retrieve(subscription.default_payment_method) + return None - customer = stripe.Customer.retrieve(subscription.customer) - if customer.invoice_settings.default_payment_method: - return stripe.PaymentMethod.retrieve( - customer.invoice_settings.default_payment_method + if self.external_id: + subscription = stripe.Subscription.retrieve( + self.external_id, expand=["default_payment_method"] ) + if subscription.default_payment_method: + return subscription.default_payment_method + + customer_id = self.stripe_get_customer_id() + customer = stripe.Customer.retrieve( + customer_id, expand=["invoice_settings.default_payment_method"] + ) + if ( + customer.invoice_settings + and customer.invoice_settings.default_payment_method + ): + return customer.invoice_settings.default_payment_method return None @@ -213,7 +251,11 @@ def stripe_get_or_create_auto_invoice( customer=customer_id, collection_method="charge_automatically", ) - invoices = [inv for inv in invoices.data if metadata_key in inv.metadata] + invoices = [ + inv + for inv in invoices.data + if inv.metadata and metadata_key in inv.metadata + ] open_invoice = next((inv for inv in invoices if inv.status == "open"), None) if open_invoice: @@ -258,10 +300,11 @@ def stripe_create_auto_invoice(self, *, amount_in_dollars: int, metadata_key: st return invoice def stripe_get_customer_id(self) -> str: - if self.payment_provider == PaymentProvider.STRIPE: + if self.payment_provider == PaymentProvider.STRIPE and self.external_id: subscription = stripe.Subscription.retrieve(self.external_id) return subscription.customer - raise ValueError("Invalid Payment Provider") + else: + return self.user.get_or_create_stripe_customer().id def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool: from payments.webhooks import StripeWebhookHandler @@ -273,6 +316,8 @@ def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool: if invoice.status != "open": return False pm = self.stripe_get_default_payment_method() + if not pm: + return False invoice = invoice.pay(payment_method=pm) if not invoice.paid: return False @@ -286,12 +331,6 @@ def get_external_management_url(self) -> str: from routers.account import account_route match self.payment_provider: - case PaymentProvider.STRIPE: - portal = stripe.billing_portal.Session.create( - customer=self.stripe_get_customer_id(), - return_url=get_app_route_url(account_route), - ) - return portal.url case PaymentProvider.PAYPAL: return str( settings.PAYPAL_WEB_BASE_URL @@ -300,6 +339,12 @@ def get_external_management_url(self) -> str: / "connect" / self.external_id ) + case PaymentProvider.STRIPE: + portal = stripe.billing_portal.Session.create( + customer=self.stripe_get_customer_id(), + return_url=get_app_route_url(account_route), + ) + return portal.url case _: raise NotImplementedError( f"Can't get management URL for subscription with provider {self.payment_provider}" @@ -328,6 +373,10 @@ def should_send_monthly_spending_notification(self) -> bool: ) -def nearest_choice(choices: list[int], value: int) -> int: - # nearest value in choices that is less than or equal to value - return min(filter(lambda x: x <= value, choices), key=lambda x: abs(x - value)) +def nearest_choice(choices: list[int], value: float) -> int: + # nearest choice that is less than or equal to the value (or the minimum choice if value is the least) + return min( + filter(lambda x: x <= value, choices), + key=lambda x: abs(x - value), + default=min(choices), + ) diff --git a/payments/webhooks.py b/payments/webhooks.py index b30cae120..0b822cfe7 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -1,8 +1,14 @@ +from copy import copy + import stripe from django.db import transaction from loguru import logger -from app_users.models import AppUser, PaymentProvider, TransactionReason +from app_users.models import ( + AppUser, + PaymentProvider, + TransactionReason, +) from daras_ai_v2 import paypal from .models import Subscription from .plans import PricingPlan @@ -59,7 +65,7 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): ) return - _set_user_subscription( + set_user_subscription( provider=cls.PROVIDER, plan=plan, uid=pp_sub.custom_id, @@ -69,8 +75,11 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): @classmethod def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" - _remove_subscription_for_user( - provider=cls.PROVIDER, uid=pp_sub.custom_id, external_id=pp_sub.id + set_user_subscription( + uid=pp_sub.custom_id, + plan=PricingPlan.STARTER, + provider=None, + external_id=None, ) @@ -79,6 +88,8 @@ class StripeWebhookHandler: @classmethod def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): + from app_users.tasks import save_stripe_default_payment_method + kwargs = {} if invoice.subscription: kwargs["plan"] = PricingPlan.get_by_key( @@ -97,32 +108,48 @@ def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): reason = TransactionReason.AUTO_RECHARGE else: reason = TransactionReason.ADDON + + amount = invoice.lines.data[0].quantity + charged_amount = invoice.lines.data[0].amount add_balance_for_payment( uid=uid, - amount=invoice.lines.data[0].quantity, + amount=amount, invoice_id=invoice.id, payment_provider=cls.PROVIDER, - charged_amount=invoice.lines.data[0].amount, + charged_amount=charged_amount, reason=reason, **kwargs, ) + save_stripe_default_payment_method.delay( + payment_intent_id=invoice.payment_intent, + uid=uid, + amount=amount, + charged_amount=charged_amount, + reason=reason, + ) + @classmethod def handle_checkout_session_completed(cls, uid: str, session_data): - if setup_intent_id := session_data.get("setup_intent") is None: + setup_intent_id = session_data.get("setup_intent") + if not setup_intent_id: # not a setup mode checkout -- do nothing return - setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) - - # subscription_id was passed to metadata when creating the session - sub_id = setup_intent.metadata["subscription_id"] - assert ( - sub_id - ), f"subscription_id is missing in setup_intent metadata {setup_intent}" - stripe.Subscription.modify( - sub_id, default_payment_method=setup_intent.payment_method - ) + setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) + if sub_id := setup_intent.metadata.get("subscription_id"): + # subscription_id was passed to metadata when creating the session + stripe.Subscription.modify( + sub_id, default_payment_method=setup_intent.payment_method + ) + elif customer_id := session_data.get("customer"): + # no subscription_id, so update the customer's default payment method instead + stripe.Customer.modify( + customer_id, + invoice_settings=dict( + default_payment_method=setup_intent.payment_method + ), + ) @classmethod def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): @@ -146,7 +173,7 @@ def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): ) return - _set_user_subscription( + set_user_subscription( provider=cls.PROVIDER, plan=plan, uid=uid, @@ -154,10 +181,12 @@ def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): ) @classmethod - def handle_subscription_cancelled(cls, uid: str, stripe_sub): - logger.info(f"Stripe subscription cancelled: {stripe_sub.id}") - _remove_subscription_for_user( - provider=cls.PROVIDER, uid=uid, external_id=stripe_sub.id + def handle_subscription_cancelled(cls, uid: str): + set_user_subscription( + uid=uid, + plan=PricingPlan.STARTER, + provider=PaymentProvider.STRIPE, + external_id=None, ) @@ -190,44 +219,37 @@ def add_balance_for_payment( send_monthly_spending_notification_email.delay(user.id) -def _set_user_subscription( - *, provider: PaymentProvider, plan: PricingPlan, uid: str, external_id: str -): +def set_user_subscription( + *, + uid: str, + plan: PricingPlan, + provider: PaymentProvider | None, + external_id: str | None, + amount: int = None, + charged_amount: int = None, +) -> Subscription: with transaction.atomic(): - subscription, created = Subscription.objects.get_or_create( - payment_provider=provider, - external_id=external_id, - defaults=dict(plan=plan.db_value), - ) - subscription.plan = plan.db_value - subscription.full_clean() - subscription.save() - user = AppUser.objects.get_or_create_from_uid(uid)[0] - existing = user.subscription - - user.subscription = subscription - user.save(update_fields=["subscription"]) - if not existing: - return + old_sub = user.subscription + if old_sub: + new_sub = copy(old_sub) + else: + old_sub = None + new_sub = Subscription() - # cancel existing subscription if it's not the same as the new one - if existing.external_id != external_id: - existing.cancel() + new_sub.plan = plan.db_value + new_sub.payment_provider = provider + new_sub.external_id = external_id + new_sub.full_clean(amount=amount, charged_amount=charged_amount) + new_sub.save() - # delete old db record if it exists - if existing.id != subscription.id: - existing.delete() + if not old_sub: + user.subscription = new_sub + user.save(update_fields=["subscription"]) + # cancel previous subscription if it's not the same as the new one + if old_sub and old_sub.external_id != external_id: + old_sub.cancel() -def _remove_subscription_for_user( - *, uid: str, provider: PaymentProvider, external_id: str -): - AppUser.objects.filter( - uid=uid, - subscription__payment_provider=provider, - subscription__external_id=external_id, - ).update( - subscription=None, - ) + return new_sub diff --git a/routers/paypal.py b/routers/paypal.py index 797812fae..933a2a9a7 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -126,7 +126,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json): if plan.deprecated: return JSONResponse({"error": "Deprecated plan"}, status_code=400) - if request.user.subscription: + if request.user.subscription and request.user.subscription.is_paid(): return JSONResponse( {"error": "User already has an active subscription"}, status_code=400 ) diff --git a/routers/stripe.py b/routers/stripe.py index 3d481d47b..04538bd55 100644 --- a/routers/stripe.py +++ b/routers/stripe.py @@ -46,6 +46,6 @@ def webhook_received(request: Request, payload: bytes = fastapi_request_body): case "customer.subscription.created" | "customer.subscription.updated": StripeWebhookHandler.handle_subscription_updated(uid, data) case "customer.subscription.deleted": - StripeWebhookHandler.handle_subscription_cancelled(uid, data) + StripeWebhookHandler.handle_subscription_cancelled(uid) return JSONResponse({"status": "success"}) diff --git a/scripts/migrate_existing_subscriptions.py b/scripts/migrate_existing_subscriptions.py index 967b5bf73..ceb7dc224 100644 --- a/scripts/migrate_existing_subscriptions.py +++ b/scripts/migrate_existing_subscriptions.py @@ -29,7 +29,7 @@ # "quantity": 1000, # number of credits (set by html) "adjustable_quantity": { "enabled": True, - "maximum": 50_000, + "maximum": 100_000, "minimum": 1_000, }, }, diff --git a/tests/test_checkout.py b/tests/test_checkout.py index 21a9a3291..a392c2cbc 100644 --- a/tests/test_checkout.py +++ b/tests/test_checkout.py @@ -3,7 +3,7 @@ from app_users.models import AppUser from daras_ai_v2 import settings -from daras_ai_v2.billing import stripe_subscription_checkout_redirect +from daras_ai_v2.billing import stripe_subscription_create from gooey_gui import RedirectException from payments.plans import PricingPlan from server import app @@ -20,4 +20,4 @@ def test_create_checkout_session( return with pytest.raises(RedirectException): - stripe_subscription_checkout_redirect(force_authentication, plan) + stripe_subscription_create(force_authentication, plan)