diff --git a/app_users/apps.py b/app_users/apps.py
index 023d4ae99..e8d750128 100644
--- a/app_users/apps.py
+++ b/app_users/apps.py
@@ -5,8 +5,3 @@ class AppUsersConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "app_users"
verbose_name = "App Users"
-
- def ready(self):
- from . import signals
-
- assert signals
diff --git a/app_users/signals.py b/app_users/signals.py
deleted file mode 100644
index 8e6b32dcf..000000000
--- a/app_users/signals.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# from django.db import transaction
-# from django.db.models.signals import post_delete
-# from django.dispatch import receiver
-# from firebase_admin import auth
-#
-# from app_users.models import AppUser
-#
-#
-# @receiver(post_delete, sender=AppUser)
-# def profile_post_delete(instance: AppUser, **kwargs):
-# if not instance.uid:
-# return
-#
-# @transaction.on_commit
-# def _():
-# try:
-# auth.delete_user(instance.uid)
-# except auth.UserNotFoundError:
-# pass
diff --git a/app_users/tasks.py b/app_users/tasks.py
new file mode 100644
index 000000000..0327ac423
--- /dev/null
+++ b/app_users/tasks.py
@@ -0,0 +1,54 @@
+import stripe
+from loguru import logger
+
+from app_users.models import PaymentProvider, TransactionReason
+from celeryapp.celeryconfig import app
+from payments.models import Subscription
+from payments.plans import PricingPlan
+from payments.webhooks import set_user_subscription
+
+
+@app.task
+def save_stripe_default_payment_method(
+ *,
+ payment_intent_id: str,
+ uid: str,
+ amount: int,
+ charged_amount: int,
+ reason: TransactionReason,
+):
+ pi = stripe.PaymentIntent.retrieve(payment_intent_id, expand=["payment_method"])
+ pm = pi.payment_method
+ if not (pm and pm.customer):
+ logger.error(
+ f"Failed to retrieve payment method for payment intent {payment_intent_id}"
+ )
+ return
+
+ # update customer's defualt payment method
+ # note: if a customer has an active subscription, the payment method attached there will be preferred
+ # see `stripe_get_default_payment_method` in payments/models.py module
+ logger.info(
+ f"Updating default payment method for customer {pm.customer} to {pm.id}"
+ )
+ stripe.Customer.modify(
+ pm.customer,
+ invoice_settings=dict(default_payment_method=pm),
+ )
+
+ # if user doesn't already have a active billing/autorecharge info, so we don't need to do anything
+ # set user's subscription to the free plan
+ if (
+ reason == TransactionReason.ADDON
+ and not Subscription.objects.filter(
+ user__uid=uid, payment_provider__isnull=False
+ ).exists()
+ ):
+ set_user_subscription(
+ uid=uid,
+ plan=PricingPlan.STARTER,
+ provider=PaymentProvider.STRIPE,
+ external_id=None,
+ amount=amount,
+ charged_amount=charged_amount,
+ )
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index fbbc7e60b..9e1c679fb 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -52,6 +52,7 @@
from daras_ai_v2.exceptions import InsufficientCredits
from daras_ai_v2.fastapi_tricks import get_route_path
from daras_ai_v2.grid_layout_widget import grid_layout
+from daras_ai_v2.gui_confirm import confirm_modal
from daras_ai_v2.html_spinner_widget import html_spinner
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_preview_url import meta_preview_url
@@ -697,10 +698,29 @@ def _render_options_modal(
save_as_new_button = gui.button(
f"{save_as_new_icon} Save as New", className="w-100"
)
- delete_button = not published_run.is_root() and gui.button(
- f' Delete',
- className="w-100 text-danger",
- )
+
+ if not published_run.is_root():
+ confirm_delete_modal, confirmed = confirm_modal(
+ title="Are you sure?",
+ key="--delete-run-modal",
+ text=f"""
+Are you sure you want to delete this published run?
+
+**{published_run.title}**
+
+This will also delete all the associated versions.
+ """,
+ button_label="Delete",
+ button_class="border-danger bg-danger text-white",
+ )
+ if gui.button(
+ f' Delete',
+ className="w-100 text-danger",
+ ):
+ confirm_delete_modal.open()
+ if confirmed:
+ published_run.delete()
+ raise gui.RedirectException(self.app_url())
if duplicate_button:
duplicate_pr = self.duplicate_published_run(
@@ -730,47 +750,6 @@ def _render_options_modal(
gui.write("#### Version History", className="mb-4")
self._render_version_history()
- confirm_delete_modal = gui.Modal("Confirm Delete", key="confirm-delete-modal")
- if delete_button:
- confirm_delete_modal.open()
- if confirm_delete_modal.is_open():
- modal.empty()
- with confirm_delete_modal.container():
- self._render_confirm_delete_modal(
- published_run=published_run,
- modal=confirm_delete_modal,
- )
-
- def _render_confirm_delete_modal(
- self,
- *,
- published_run: PublishedRun,
- modal: gui.Modal,
- ):
- gui.write(
- "Are you sure you want to delete this published run? "
- f"_({published_run.title})_"
- )
- gui.caption("This will also delete all the associated versions.")
- with gui.div(className="d-flex"):
- confirm_button = gui.button(
- 'Confirm',
- type="secondary",
- className="w-100",
- )
- cancel_button = gui.button(
- "Cancel",
- type="secondary",
- className="w-100",
- )
-
- if confirm_button:
- published_run.delete()
- raise gui.RedirectException(self.app_url())
-
- if cancel_button:
- modal.close()
-
def _render_admin_options(self, current_run: SavedRun, published_run: PublishedRun):
if (
not self.is_current_user_admin()
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index 09b48d375..25e607415 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -1,5 +1,3 @@
-from typing import Literal
-
import gooey_gui as gui
import stripe
from django.core.exceptions import ValidationError
@@ -8,22 +6,21 @@
from daras_ai_v2 import icons, settings, paypal
from daras_ai_v2.fastapi_tricks import get_app_route_url
from daras_ai_v2.grid_layout_widget import grid_layout
+from daras_ai_v2.gui_confirm import confirm_modal
from daras_ai_v2.settings import templates
from daras_ai_v2.user_date_widgets import render_local_date_attrs
from payments.models import PaymentMethodSummary
from payments.plans import PricingPlan
+from payments.webhooks import StripeWebhookHandler, set_user_subscription
from scripts.migrate_existing_subscriptions import available_subscriptions
rounded_border = "w-100 border shadow-sm rounded py-4 px-3"
-PlanActionLabel = Literal["Upgrade", "Downgrade", "Contact Us", "Your Plan"]
-
-
def billing_page(user: AppUser):
render_payments_setup()
- if user.subscription:
+ if user.subscription and user.subscription.is_paid():
render_current_plan(user)
with gui.div(className="my-5"):
@@ -35,10 +32,11 @@ def billing_page(user: AppUser):
with gui.div(className="my-5"):
render_addon_section(user, selected_payment_provider)
- if user.subscription and user.subscription.payment_provider:
+ if user.subscription:
if user.subscription.payment_provider == PaymentProvider.STRIPE:
with gui.div(className="my-5"):
render_auto_recharge_section(user)
+
with gui.div(className="my-5"):
render_payment_information(user)
@@ -59,11 +57,10 @@ def render_payments_setup():
def render_current_plan(user: AppUser):
plan = PricingPlan.from_sub(user.subscription)
- provider = (
- PaymentProvider(user.subscription.payment_provider)
- if user.subscription.payment_provider
- else None
- )
+ if user.subscription.payment_provider:
+ provider = PaymentProvider(user.subscription.payment_provider)
+ else:
+ provider = None
with gui.div(className=f"{rounded_border} border-dark"):
# ROW 1: Plan title and next invoice date
@@ -101,7 +98,10 @@ def render_current_plan(user: AppUser):
with left:
gui.write(f"# {plan.pricing_title()}", className="no-margin")
if plan.monthly_charge:
- provider_text = f" **via {provider.label}**" if provider else ""
+ if provider:
+ provider_text = f" **via {provider.label}**"
+ else:
+ provider_text = ""
gui.caption("per month" + provider_text)
with right, gui.div(className="text-end"):
@@ -131,7 +131,7 @@ def render_all_plans(user: AppUser) -> PaymentProvider:
plans_div = gui.div(className="mb-1")
if user.subscription and user.subscription.payment_provider:
- selected_payment_provider = None
+ selected_payment_provider = user.subscription.payment_provider
else:
with gui.div():
selected_payment_provider = PaymentProvider[
@@ -149,7 +149,10 @@ def _render_plan(plan: PricingPlan):
):
_render_plan_details(plan)
_render_plan_action_button(
- user, plan, current_plan, selected_payment_provider
+ user=user,
+ plan=plan,
+ current_plan=current_plan,
+ payment_provider=selected_payment_provider,
)
with plans_div:
@@ -198,30 +201,59 @@ def _render_plan_action_button(
className=btn_classes + " btn btn-theme btn-primary",
):
gui.html("Contact Us")
- elif user.subscription and not user.subscription.payment_provider:
+ elif (
+ user.subscription and user.subscription.plan == PricingPlan.ENTERPRISE.db_value
+ ):
# don't show upgrade/downgrade buttons for enterprise customers
- # assumption: anyone without a payment provider attached is admin/enterprise
return
else:
- if plan.credits > current_plan.credits:
- label, btn_type = ("Upgrade", "primary")
- else:
- label, btn_type = ("Downgrade", "secondary")
-
- if user.subscription and user.subscription.payment_provider:
+ if user.subscription and user.subscription.is_paid():
# subscription exists, show upgrade/downgrade button
- _render_update_subscription_button(
- label,
- user=user,
- current_plan=current_plan,
- plan=plan,
- className=f"{btn_classes} btn btn-theme btn-{btn_type}",
- )
+ if plan.credits > current_plan.credits:
+ modal, confirmed = confirm_modal(
+ title="Upgrade Plan",
+ key=f"--modal-{plan.key}",
+ text=f"""
+Are you sure you want to upgrade from: **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**?
+
+This will charge you the full amount today, and every month thereafter.
+
+**{current_plan.credits:,} credits** will be added to your account.
+ """,
+ button_label="Buy",
+ )
+ if gui.button(
+ "Upgrade", className="primary", key=f"--change-sub-{plan.key}"
+ ):
+ modal.open()
+ if confirmed:
+ change_subscription(
+ user,
+ plan,
+ # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time
+ billing_cycle_anchor="now",
+ )
+ else:
+ modal, confirmed = confirm_modal(
+ title="Downgrade Plan",
+ key=f"--modal-{plan.key}",
+ text=f"""
+Are you sure you want to downgrade from: **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**?
+
+This will take effect from the next billing cycle.
+ """,
+ button_label="Downgrade",
+ button_class="border-danger bg-danger text-white",
+ )
+ if gui.button(
+ "Downgrade", className="secondary", key=f"--change-sub-{plan.key}"
+ ):
+ modal.open()
+ if confirmed:
+ change_subscription(user, plan)
else:
assert payment_provider is not None # for sanity
_render_create_subscription_button(
- label,
- btn_type=btn_type,
user=user,
plan=plan,
payment_provider=payment_provider,
@@ -229,81 +261,18 @@ def _render_plan_action_button(
def _render_create_subscription_button(
- label: PlanActionLabel,
*,
- btn_type: str,
user: AppUser,
plan: PricingPlan,
payment_provider: PaymentProvider,
):
match payment_provider:
case PaymentProvider.STRIPE:
- key = f"stripe-sub-{plan.key}"
- render_stripe_subscription_button(
- user=user,
- label=label,
- plan=plan,
- btn_type=btn_type,
- key=key,
- )
+ render_stripe_subscription_button(user=user, plan=plan)
case PaymentProvider.PAYPAL:
render_paypal_subscription_button(plan=plan)
-def _render_update_subscription_button(
- label: PlanActionLabel,
- *,
- user: AppUser,
- current_plan: PricingPlan,
- plan: PricingPlan,
- className: str = "",
-):
- key = f"change-sub-{plan.key}"
- match label:
- case "Downgrade":
- downgrade_modal = gui.Modal(
- "Confirm downgrade",
- key=f"downgrade-plan-modal-{plan.key}",
- )
- if gui.button(
- label,
- className=className,
- key=key,
- ):
- downgrade_modal.open()
-
- if downgrade_modal.is_open():
- with downgrade_modal.container():
- gui.write(
- f"""
- Are you sure you want to change from:
- **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**?
- """,
- className="d-block py-4",
- )
- with gui.div(className="d-flex w-100"):
- if gui.button(
- "Downgrade",
- className="btn btn-theme bg-danger border-danger text-white",
- key=f"{key}-confirm",
- ):
- change_subscription(user, plan)
- if gui.button(
- "Cancel",
- className="border border-danger text-danger",
- key=f"{key}-cancel",
- ):
- downgrade_modal.close()
- case _:
- if gui.button(label, className=className, key=key):
- change_subscription(
- user,
- plan,
- # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time
- billing_cycle_anchor="now",
- )
-
-
def fmt_price(plan: PricingPlan) -> str:
if plan.monthly_charge:
return f"${plan.monthly_charge:,}/month"
@@ -322,7 +291,6 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs):
if new_plan == PricingPlan.STARTER:
user.subscription.cancel()
- user.subscription.delete()
raise gui.RedirectException(
get_app_route_url(payment_processing_route), status_code=303
)
@@ -383,7 +351,7 @@ def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvid
gui.write("# Purchase Credits")
gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
- if user.subscription:
+ if user.subscription and user.subscription.payment_provider:
provider = PaymentProvider(user.subscription.payment_provider)
else:
provider = selected_payment_provider
@@ -414,57 +382,66 @@ def render_paypal_addon_buttons():
def render_stripe_addon_buttons(user: AppUser):
+ if not (user.subscription and user.subscription.payment_provider):
+ save_pm = gui.checkbox(
+ "Save payment method for future purchases & auto-recharge", value=True
+ )
+ else:
+ save_pm = True
+
for dollat_amt in settings.ADDON_AMOUNT_CHOICES:
- render_stripe_addon_button(dollat_amt, user)
+ render_stripe_addon_button(dollat_amt, user, save_pm)
+ error = gui.session_state.pop("--addon-purchase-error", None)
+ if error:
+ gui.error(error)
-def render_stripe_addon_button(dollat_amt: int, user: AppUser):
- confirm_purchase_modal = gui.Modal(
- "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}"
+
+def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool):
+ modal, confirmed = confirm_modal(
+ title="Purchase Credits",
+ key=f"--addon-modal-{dollat_amt}",
+ text=f"""
+Please confirm your purchase: **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**.
+
+This is a one-time purchase. Your account will be credited immediately.
+ """,
+ button_label="Buy",
+ text_on_confirm="Processing Payment...",
)
+
if gui.button(f"${dollat_amt:,}", type="primary"):
- if user.subscription:
- confirm_purchase_modal.open()
+ if user.subscription and user.subscription.stripe_get_default_payment_method():
+ modal.open()
else:
- stripe_addon_checkout_redirect(user, dollat_amt)
+ stripe_addon_checkout_redirect(user, dollat_amt, save_pm)
- if not confirm_purchase_modal.is_open():
- return
- with confirm_purchase_modal.container():
- gui.write(
- f"""
- Please confirm your purchase:
- **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**.
- """,
- className="py-4 d-block text-center",
+ if confirmed:
+ success = gui.run_in_thread(
+ user.subscription.stripe_attempt_addon_purchase,
+ args=[dollat_amt],
+ placeholder="",
)
- with gui.div(className="d-flex w-100 justify-content-end"):
- if gui.session_state.get("--confirm-purchase"):
- success = gui.run_in_thread(
- user.subscription.stripe_attempt_addon_purchase,
- args=[dollat_amt],
- placeholder="Processing payment...",
- )
- if success is None:
- return
- gui.session_state.pop("--confirm-purchase")
- if success:
- confirm_purchase_modal.close()
- else:
- gui.error("Payment failed... Please try again.")
- return
-
- if gui.button("Cancel", className="border border-danger text-danger me-2"):
- confirm_purchase_modal.close()
- gui.button("Buy", type="primary", key="--confirm-purchase")
+ if success is None:
+ return
+ if not success:
+ gui.session_state["--addon-purchase-error"] = (
+ "Payment failed... Please try again or contact us at support@gooey.ai"
+ )
+ modal.close()
-def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int):
+def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool):
from routers.account import account_route
from routers.account import payment_processing_route
line_item = available_subscriptions["addon"]["stripe"].copy()
line_item["quantity"] = dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR
+ kwargs = {}
+ if save_pm:
+ kwargs["payment_intent_data"] = {"setup_future_usage": "on_session"}
+ else:
+ kwargs["saved_payment_method_options"] = {"payment_method_save": "enabled"}
checkout_session = stripe.checkout.Session.create(
line_items=[line_item],
mode="payment",
@@ -473,55 +450,98 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int):
customer=user.get_or_create_stripe_customer(),
invoice_creation={"enabled": True},
allow_promotion_codes=True,
- saved_payment_method_options={
- "payment_method_save": "enabled",
- },
+ **kwargs,
)
raise gui.RedirectException(checkout_session.url, status_code=303)
def render_stripe_subscription_button(
*,
- label: str,
user: AppUser,
plan: PricingPlan,
- btn_type: str,
- key: str,
):
if not plan.supports_stripe():
gui.write("Stripe subscription not available")
return
+ modal, confirmed = confirm_modal(
+ title="Upgrade Plan",
+ key=f"--modal-{plan.key}",
+ text=f"""
+Are you sure you want to subscribe to **{plan.title} ({fmt_price(plan)})**?
+
+This will charge you the full amount today, and every month thereafter.
+
+**{plan.credits:,} credits** will be added to your account.
+ """,
+ button_label="Buy",
+ )
+
# IMPORTANT: key=... is needed here to maintain uniqueness
# of buttons with the same label. otherwise, all buttons
# will be the same to the server
- if gui.button(label, key=key, type=btn_type):
- stripe_subscription_checkout_redirect(user=user, plan=plan)
+ if gui.button(
+ "Upgrade",
+ key=f"--change-sub-{plan.key}",
+ type="primary",
+ ):
+ if user.subscription and user.subscription.stripe_get_default_payment_method():
+ modal.open()
+ else:
+ stripe_subscription_create(user=user, plan=plan)
+ if confirmed:
+ stripe_subscription_create(user=user, plan=plan)
-def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan):
+
+def stripe_subscription_create(user: AppUser, plan: PricingPlan):
from routers.account import account_route
from routers.account import payment_processing_route
- if user.subscription:
+ if user.subscription and user.subscription.plan == plan.db_value:
# already subscribed to some plan
return
+ pm = user.subscription and user.subscription.stripe_get_default_payment_method()
metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key}
- checkout_session = stripe.checkout.Session.create(
- line_items=[(plan.get_stripe_line_item())],
- mode="subscription",
- success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
- customer=user.get_or_create_stripe_customer(),
- metadata=metadata,
- subscription_data={"metadata": metadata},
- allow_promotion_codes=True,
- saved_payment_method_options={
- "payment_method_save": "enabled",
- },
- )
- raise gui.RedirectException(checkout_session.url, status_code=303)
+ line_items = [plan.get_stripe_line_item()]
+ if pm:
+ # directly create the subscription without checkout
+ stripe.Subscription.create(
+ customer=pm.customer,
+ items=line_items,
+ metadata=metadata,
+ default_payment_method=pm.id,
+ proration_behavior="none",
+ )
+ raise gui.RedirectException(
+ get_app_route_url(payment_processing_route), status_code=303
+ )
+ else:
+ # check for existing subscriptions
+ customer = user.get_or_create_stripe_customer()
+ for sub in stripe.Subscription.list(
+ customer=customer, status="active", limit=1
+ ).data:
+ StripeWebhookHandler.handle_subscription_updated(
+ uid=user.uid, stripe_sub=sub
+ )
+ raise gui.RedirectException(
+ get_app_route_url(payment_processing_route), status_code=303
+ )
+
+ checkout_session = stripe.checkout.Session.create(
+ mode="subscription",
+ success_url=get_app_route_url(payment_processing_route),
+ cancel_url=get_app_route_url(account_route),
+ allow_promotion_codes=True,
+ customer=customer,
+ line_items=line_items,
+ metadata=metadata,
+ subscription_data={"metadata": metadata},
+ saved_payment_method_options={"payment_method_save": "enabled"},
+ )
+ raise gui.RedirectException(checkout_session.url, status_code=303)
def render_paypal_subscription_button(
@@ -544,44 +564,90 @@ def render_paypal_subscription_button(
def render_payment_information(user: AppUser):
- assert user.subscription
-
- gui.write("## Payment Information", id="payment-information", className="d-block")
- col1, col2, col3 = gui.columns(3, responsive=False)
- with col1:
- gui.write("**Pay via**")
- with col2:
- provider = PaymentProvider(user.subscription.payment_provider)
- gui.write(provider.label)
- with col3:
- if gui.button(f"{icons.edit} Edit", type="link", key="manage-payment-provider"):
- raise gui.RedirectException(user.subscription.get_external_management_url())
+ if not user.subscription:
+ return
pm_summary = gui.run_in_thread(
user.subscription.get_payment_method_summary, cache=True
)
if not pm_summary:
return
- pm_summary = PaymentMethodSummary(*pm_summary)
- if pm_summary.card_brand and pm_summary.card_last4:
+
+ gui.write("## Payment Information", id="payment-information", className="d-block")
+ with gui.div(className="ps-1"):
col1, col2, col3 = gui.columns(3, responsive=False)
with col1:
- gui.write("**Payment Method**")
+ gui.write("**Pay via**")
with col2:
- gui.write(
- f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}",
- unsafe_allow_html=True,
+ provider = PaymentProvider(
+ user.subscription.payment_provider or PaymentProvider.STRIPE
)
+ gui.write(provider.label)
with col3:
- if gui.button(f"{icons.edit} Edit", type="link", key="edit-payment-method"):
- change_payment_method(user)
+ if gui.button(
+ f"{icons.edit} Edit", type="link", key="manage-payment-provider"
+ ):
+ raise gui.RedirectException(
+ user.subscription.get_external_management_url()
+ )
- if pm_summary.billing_email:
- col1, col2, _ = gui.columns(3, responsive=False)
- with col1:
- gui.write("**Billing Email**")
- with col2:
- gui.html(pm_summary.billing_email)
+ pm_summary = PaymentMethodSummary(*pm_summary)
+ if pm_summary.card_brand:
+ col1, col2, col3 = gui.columns(3, responsive=False)
+ with col1:
+ gui.write("**Payment Method**")
+ with col2:
+ if pm_summary.card_last4:
+ gui.write(
+ f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}",
+ unsafe_allow_html=True,
+ )
+ else:
+ gui.write(pm_summary.card_brand)
+ with col3:
+ if gui.button(
+ f"{icons.edit} Edit", type="link", key="edit-payment-method"
+ ):
+ change_payment_method(user)
+
+ if pm_summary.billing_email:
+ col1, col2, _ = gui.columns(3, responsive=False)
+ with col1:
+ gui.write("**Billing Email**")
+ with col2:
+ gui.html(pm_summary.billing_email)
+
+ from routers.account import payment_processing_route
+
+ modal, confirmed = confirm_modal(
+ title="Delete Payment Information",
+ key="--delete-payment-method",
+ text="""
+Are you sure you want to delete your payment information?
+
+This will cancel your subscription and remove your saved payment method.
+ """,
+ button_label="Delete",
+ button_class="border-danger bg-danger text-white",
+ )
+ if gui.button(
+ "Delete & Cancel Subscription",
+ className="border-danger text-danger",
+ ):
+ modal.open()
+ if confirmed:
+ set_user_subscription(
+ uid=user.uid,
+ plan=PricingPlan.STARTER,
+ provider=None,
+ external_id=None,
+ )
+ pm = user.subscription and user.subscription.stripe_get_default_payment_method()
+ if pm:
+ pm.detach()
+ raise gui.RedirectException(
+ get_app_route_url(payment_processing_route), status_code=303
+ )
def change_payment_method(user: AppUser):
@@ -636,10 +702,7 @@ def render_billing_history(user: AppUser, limit: int = 50):
def render_auto_recharge_section(user: AppUser):
- assert (
- user.subscription
- and user.subscription.payment_provider == PaymentProvider.STRIPE
- )
+ assert user.subscription
subscription = user.subscription
gui.write("## Auto Recharge & Limits")
diff --git a/daras_ai_v2/gui_confirm.py b/daras_ai_v2/gui_confirm.py
new file mode 100644
index 000000000..dd719d7a5
--- /dev/null
+++ b/daras_ai_v2/gui_confirm.py
@@ -0,0 +1,42 @@
+import gooey_gui as gui
+
+from daras_ai_v2.html_spinner_widget import html_spinner
+
+
+def confirm_modal(
+ *,
+ title: str,
+ key: str,
+ text: str,
+ button_label: str,
+ button_class: str = "",
+ text_on_confirm: str | None = None,
+) -> tuple[gui.Modal, bool]:
+ modal = gui.Modal(title, key=key)
+ confirmed_key = f"{key}-confirmed"
+ if modal.is_open():
+ with modal.container():
+ with gui.div(className="pt-4 pb-3"):
+ gui.write(text)
+ with gui.div(className="d-flex w-100 justify-content-end"):
+ confirmed = bool(gui.session_state.get(confirmed_key, None))
+ if confirmed and text_on_confirm:
+ html_spinner(text_on_confirm)
+ else:
+ if gui.button(
+ "Cancel",
+ type="tertiary",
+ className="me-2",
+ key=f"{key}-cancelled",
+ ):
+ modal.close()
+ confirmed = gui.button(
+ button_label,
+ type="primary",
+ key=confirmed_key,
+ className=button_class,
+ )
+ return modal, confirmed
+ else:
+ gui.session_state.pop(confirmed_key, None)
+ return modal, False
diff --git a/payments/admin.py b/payments/admin.py
index c889bc109..9fd8f910f 100644
--- a/payments/admin.py
+++ b/payments/admin.py
@@ -5,4 +5,14 @@
@admin.register(Subscription)
class SubscriptionAdmin(admin.ModelAdmin):
- search_fields = ["plan", "payment_provider", "external_id"]
+ search_fields = [
+ "plan",
+ "payment_provider",
+ "external_id",
+ ]
+ readonly_fields = [
+ "user",
+ "created_at",
+ "updated_at",
+ "get_payment_method_summary",
+ ]
diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py
index 1aa414cfb..14d6ba49d 100644
--- a/payments/auto_recharge.py
+++ b/payments/auto_recharge.py
@@ -107,6 +107,9 @@ def _auto_recharge_user(user: AppUser):
# get default payment method and attempt payment
assert invoice.status == "open" # sanity check
pm = user.subscription.stripe_get_default_payment_method()
+ if not pm:
+ logger.warning(f"{user} has no default payment method, cannot auto-recharge")
+ return
try:
invoice_data = invoice.pay(payment_method=pm)
diff --git a/payments/models.py b/payments/models.py
index a7deadc5e..5a200ba1e 100644
--- a/payments/models.py
+++ b/payments/models.py
@@ -5,6 +5,7 @@
import stripe
from django.db import models
+from django.db.models import Q
from django.utils import timezone
from app_users.models import PaymentProvider
@@ -72,7 +73,13 @@ class Subscription(models.Model):
objects = SubscriptionQuerySet.as_manager()
class Meta:
- unique_together = ("payment_provider", "external_id")
+ constraints = [
+ models.UniqueConstraint(
+ fields=["payment_provider", "external_id"],
+ condition=Q(plan__monthly_charge__gt=0),
+ name="unique_provider_and_subscription_id",
+ )
+ ]
indexes = [
models.Index(fields=["plan"]),
]
@@ -85,37 +92,39 @@ def __str__(self):
ret = f"Auto | {ret}"
return ret
- def full_clean(self, *args, **kwargs):
+ def full_clean(
+ self, amount: int = None, charged_amount: int = None, *args, **kwargs
+ ):
+ if self.auto_recharge_enabled:
+ if amount is None:
+ amount = PricingPlan.from_sub(self).credits
+ if charged_amount is None:
+ charged_amount = PricingPlan.from_sub(self).monthly_charge * 100
+ self.ensure_default_auto_recharge_params(
+ amount=amount, charged_amount=charged_amount
+ )
+ return super().full_clean(*args, **kwargs)
+
+ def ensure_default_auto_recharge_params(self, *, amount: int, charged_amount: int):
+ if amount <= 0 or charged_amount <= 0:
+ return
+
if not self.auto_recharge_balance_threshold:
- self.auto_recharge_balance_threshold = (
- self._get_default_auto_recharge_balance_threshold()
+ # 25% of the credits
+ self.auto_recharge_balance_threshold = nearest_choice(
+ settings.AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES, 0.25 * amount
)
if not self.monthly_spending_budget:
- self.monthly_spending_budget = self._get_default_monthly_spending_budget()
+ # 3x the charged amount
+ self.monthly_spending_budget = 3 * charged_amount / 100 # in dollars
if not self.monthly_spending_notification_threshold:
- self.monthly_spending_notification_threshold = (
- self._get_default_monthly_spending_notification_threshold()
+ # 80% of the monthly budget
+ self.monthly_spending_notification_threshold = int(
+ 0.8 * self.monthly_spending_budget
)
- return super().full_clean(*args, **kwargs)
-
- def _get_default_auto_recharge_balance_threshold(self):
- # 25% of the monthly credit subscription
- threshold = int(PricingPlan.from_sub(self).credits * 0.25)
- return nearest_choice(
- settings.AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES, threshold
- )
-
- def _get_default_monthly_spending_budget(self):
- # 3x the monthly subscription charge
- return 3 * PricingPlan.from_sub(self).monthly_charge
-
- def _get_default_monthly_spending_notification_threshold(self):
- # 80% of the monthly budget
- return int(0.8 * self._get_default_monthly_spending_budget())
-
@property
def has_user(self) -> bool:
try:
@@ -125,10 +134,26 @@ def has_user(self) -> bool:
else:
return True
+ def is_paid(self) -> bool:
+ return PricingPlan.from_sub(self).monthly_charge > 0 and self.external_id
+
def cancel(self):
+ from payments.webhooks import StripeWebhookHandler
+
+ if not self.is_paid():
+ return
+
match self.payment_provider:
case PaymentProvider.STRIPE:
- stripe.Subscription.cancel(self.external_id)
+ try:
+ stripe.Subscription.cancel(self.external_id)
+ except stripe.error.InvalidRequestError as e:
+ if e.code == "resource_missing":
+ StripeWebhookHandler.handle_subscription_cancelled(
+ self.user.uid
+ )
+ else:
+ raise
case PaymentProvider.PAYPAL:
paypal.Subscription.retrieve(self.external_id).cancel()
case _:
@@ -156,16 +181,6 @@ def get_next_invoice_timestamp(self) -> float | None:
def get_payment_method_summary(self) -> PaymentMethodSummary | None:
match self.payment_provider:
- case PaymentProvider.STRIPE:
- pm = self.stripe_get_default_payment_method()
- if not pm:
- return None
- return PaymentMethodSummary(
- payment_method_type=pm.type,
- card_brand=pm.card and pm.card.brand,
- card_last4=pm.card and pm.card.last4,
- billing_email=(pm.billing_details and pm.billing_details.email),
- )
case PaymentProvider.PAYPAL:
subscription = paypal.Subscription.retrieve(self.external_id)
subscriber = subscription.subscriber
@@ -178,20 +193,43 @@ def get_payment_method_summary(self) -> PaymentMethodSummary | None:
card_last4=source.get("card", {}).get("last_digits"),
billing_email=subscriber.email_address,
)
+ case PaymentProvider.STRIPE:
+ pm = self.stripe_get_default_payment_method()
+ if not pm:
+ # clear the payment provider if the default payment method is missing
+ if self.payment_provider and not self.is_paid():
+ self.payment_provider = None
+ self.save(update_fields=["payment_provider"])
+ return None
+ return PaymentMethodSummary(
+ payment_method_type=pm.type,
+ card_brand=(
+ (pm.type == "card" and pm.card and pm.card.brand) or pm.type
+ ),
+ card_last4=(pm.type == "card" and pm.card and pm.card.last4) or "",
+ billing_email=(pm.billing_details and pm.billing_details.email),
+ )
def stripe_get_default_payment_method(self) -> stripe.PaymentMethod | None:
if self.payment_provider != PaymentProvider.STRIPE:
- raise ValueError("Invalid Payment Provider")
-
- subscription = stripe.Subscription.retrieve(self.external_id)
- if subscription.default_payment_method:
- return stripe.PaymentMethod.retrieve(subscription.default_payment_method)
+ return None
- customer = stripe.Customer.retrieve(subscription.customer)
- if customer.invoice_settings.default_payment_method:
- return stripe.PaymentMethod.retrieve(
- customer.invoice_settings.default_payment_method
+ if self.external_id:
+ subscription = stripe.Subscription.retrieve(
+ self.external_id, expand=["default_payment_method"]
)
+ if subscription.default_payment_method:
+ return subscription.default_payment_method
+
+ customer_id = self.stripe_get_customer_id()
+ customer = stripe.Customer.retrieve(
+ customer_id, expand=["invoice_settings.default_payment_method"]
+ )
+ if (
+ customer.invoice_settings
+ and customer.invoice_settings.default_payment_method
+ ):
+ return customer.invoice_settings.default_payment_method
return None
@@ -213,7 +251,11 @@ def stripe_get_or_create_auto_invoice(
customer=customer_id,
collection_method="charge_automatically",
)
- invoices = [inv for inv in invoices.data if metadata_key in inv.metadata]
+ invoices = [
+ inv
+ for inv in invoices.data
+ if inv.metadata and metadata_key in inv.metadata
+ ]
open_invoice = next((inv for inv in invoices if inv.status == "open"), None)
if open_invoice:
@@ -258,10 +300,11 @@ def stripe_create_auto_invoice(self, *, amount_in_dollars: int, metadata_key: st
return invoice
def stripe_get_customer_id(self) -> str:
- if self.payment_provider == PaymentProvider.STRIPE:
+ if self.payment_provider == PaymentProvider.STRIPE and self.external_id:
subscription = stripe.Subscription.retrieve(self.external_id)
return subscription.customer
- raise ValueError("Invalid Payment Provider")
+ else:
+ return self.user.get_or_create_stripe_customer().id
def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool:
from payments.webhooks import StripeWebhookHandler
@@ -273,6 +316,8 @@ def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool:
if invoice.status != "open":
return False
pm = self.stripe_get_default_payment_method()
+ if not pm:
+ return False
invoice = invoice.pay(payment_method=pm)
if not invoice.paid:
return False
@@ -286,12 +331,6 @@ def get_external_management_url(self) -> str:
from routers.account import account_route
match self.payment_provider:
- case PaymentProvider.STRIPE:
- portal = stripe.billing_portal.Session.create(
- customer=self.stripe_get_customer_id(),
- return_url=get_app_route_url(account_route),
- )
- return portal.url
case PaymentProvider.PAYPAL:
return str(
settings.PAYPAL_WEB_BASE_URL
@@ -300,6 +339,12 @@ def get_external_management_url(self) -> str:
/ "connect"
/ self.external_id
)
+ case PaymentProvider.STRIPE:
+ portal = stripe.billing_portal.Session.create(
+ customer=self.stripe_get_customer_id(),
+ return_url=get_app_route_url(account_route),
+ )
+ return portal.url
case _:
raise NotImplementedError(
f"Can't get management URL for subscription with provider {self.payment_provider}"
@@ -328,6 +373,10 @@ def should_send_monthly_spending_notification(self) -> bool:
)
-def nearest_choice(choices: list[int], value: int) -> int:
- # nearest value in choices that is less than or equal to value
- return min(filter(lambda x: x <= value, choices), key=lambda x: abs(x - value))
+def nearest_choice(choices: list[int], value: float) -> int:
+ # nearest choice that is less than or equal to the value (or the minimum choice if value is the least)
+ return min(
+ filter(lambda x: x <= value, choices),
+ key=lambda x: abs(x - value),
+ default=min(choices),
+ )
diff --git a/payments/webhooks.py b/payments/webhooks.py
index b30cae120..0b822cfe7 100644
--- a/payments/webhooks.py
+++ b/payments/webhooks.py
@@ -1,8 +1,14 @@
+from copy import copy
+
import stripe
from django.db import transaction
from loguru import logger
-from app_users.models import AppUser, PaymentProvider, TransactionReason
+from app_users.models import (
+ AppUser,
+ PaymentProvider,
+ TransactionReason,
+)
from daras_ai_v2 import paypal
from .models import Subscription
from .plans import PricingPlan
@@ -59,7 +65,7 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
)
return
- _set_user_subscription(
+ set_user_subscription(
provider=cls.PROVIDER,
plan=plan,
uid=pp_sub.custom_id,
@@ -69,8 +75,11 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
@classmethod
def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription):
assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid"
- _remove_subscription_for_user(
- provider=cls.PROVIDER, uid=pp_sub.custom_id, external_id=pp_sub.id
+ set_user_subscription(
+ uid=pp_sub.custom_id,
+ plan=PricingPlan.STARTER,
+ provider=None,
+ external_id=None,
)
@@ -79,6 +88,8 @@ class StripeWebhookHandler:
@classmethod
def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice):
+ from app_users.tasks import save_stripe_default_payment_method
+
kwargs = {}
if invoice.subscription:
kwargs["plan"] = PricingPlan.get_by_key(
@@ -97,32 +108,48 @@ def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice):
reason = TransactionReason.AUTO_RECHARGE
else:
reason = TransactionReason.ADDON
+
+ amount = invoice.lines.data[0].quantity
+ charged_amount = invoice.lines.data[0].amount
add_balance_for_payment(
uid=uid,
- amount=invoice.lines.data[0].quantity,
+ amount=amount,
invoice_id=invoice.id,
payment_provider=cls.PROVIDER,
- charged_amount=invoice.lines.data[0].amount,
+ charged_amount=charged_amount,
reason=reason,
**kwargs,
)
+ save_stripe_default_payment_method.delay(
+ payment_intent_id=invoice.payment_intent,
+ uid=uid,
+ amount=amount,
+ charged_amount=charged_amount,
+ reason=reason,
+ )
+
@classmethod
def handle_checkout_session_completed(cls, uid: str, session_data):
- if setup_intent_id := session_data.get("setup_intent") is None:
+ setup_intent_id = session_data.get("setup_intent")
+ if not setup_intent_id:
# not a setup mode checkout -- do nothing
return
- setup_intent = stripe.SetupIntent.retrieve(setup_intent_id)
-
- # subscription_id was passed to metadata when creating the session
- sub_id = setup_intent.metadata["subscription_id"]
- assert (
- sub_id
- ), f"subscription_id is missing in setup_intent metadata {setup_intent}"
- stripe.Subscription.modify(
- sub_id, default_payment_method=setup_intent.payment_method
- )
+ setup_intent = stripe.SetupIntent.retrieve(setup_intent_id)
+ if sub_id := setup_intent.metadata.get("subscription_id"):
+ # subscription_id was passed to metadata when creating the session
+ stripe.Subscription.modify(
+ sub_id, default_payment_method=setup_intent.payment_method
+ )
+ elif customer_id := session_data.get("customer"):
+ # no subscription_id, so update the customer's default payment method instead
+ stripe.Customer.modify(
+ customer_id,
+ invoice_settings=dict(
+ default_payment_method=setup_intent.payment_method
+ ),
+ )
@classmethod
def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription):
@@ -146,7 +173,7 @@ def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription):
)
return
- _set_user_subscription(
+ set_user_subscription(
provider=cls.PROVIDER,
plan=plan,
uid=uid,
@@ -154,10 +181,12 @@ def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription):
)
@classmethod
- def handle_subscription_cancelled(cls, uid: str, stripe_sub):
- logger.info(f"Stripe subscription cancelled: {stripe_sub.id}")
- _remove_subscription_for_user(
- provider=cls.PROVIDER, uid=uid, external_id=stripe_sub.id
+ def handle_subscription_cancelled(cls, uid: str):
+ set_user_subscription(
+ uid=uid,
+ plan=PricingPlan.STARTER,
+ provider=PaymentProvider.STRIPE,
+ external_id=None,
)
@@ -190,44 +219,37 @@ def add_balance_for_payment(
send_monthly_spending_notification_email.delay(user.id)
-def _set_user_subscription(
- *, provider: PaymentProvider, plan: PricingPlan, uid: str, external_id: str
-):
+def set_user_subscription(
+ *,
+ uid: str,
+ plan: PricingPlan,
+ provider: PaymentProvider | None,
+ external_id: str | None,
+ amount: int = None,
+ charged_amount: int = None,
+) -> Subscription:
with transaction.atomic():
- subscription, created = Subscription.objects.get_or_create(
- payment_provider=provider,
- external_id=external_id,
- defaults=dict(plan=plan.db_value),
- )
- subscription.plan = plan.db_value
- subscription.full_clean()
- subscription.save()
-
user = AppUser.objects.get_or_create_from_uid(uid)[0]
- existing = user.subscription
-
- user.subscription = subscription
- user.save(update_fields=["subscription"])
- if not existing:
- return
+ old_sub = user.subscription
+ if old_sub:
+ new_sub = copy(old_sub)
+ else:
+ old_sub = None
+ new_sub = Subscription()
- # cancel existing subscription if it's not the same as the new one
- if existing.external_id != external_id:
- existing.cancel()
+ new_sub.plan = plan.db_value
+ new_sub.payment_provider = provider
+ new_sub.external_id = external_id
+ new_sub.full_clean(amount=amount, charged_amount=charged_amount)
+ new_sub.save()
- # delete old db record if it exists
- if existing.id != subscription.id:
- existing.delete()
+ if not old_sub:
+ user.subscription = new_sub
+ user.save(update_fields=["subscription"])
+ # cancel previous subscription if it's not the same as the new one
+ if old_sub and old_sub.external_id != external_id:
+ old_sub.cancel()
-def _remove_subscription_for_user(
- *, uid: str, provider: PaymentProvider, external_id: str
-):
- AppUser.objects.filter(
- uid=uid,
- subscription__payment_provider=provider,
- subscription__external_id=external_id,
- ).update(
- subscription=None,
- )
+ return new_sub
diff --git a/routers/paypal.py b/routers/paypal.py
index 797812fae..933a2a9a7 100644
--- a/routers/paypal.py
+++ b/routers/paypal.py
@@ -126,7 +126,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
if plan.deprecated:
return JSONResponse({"error": "Deprecated plan"}, status_code=400)
- if request.user.subscription:
+ if request.user.subscription and request.user.subscription.is_paid():
return JSONResponse(
{"error": "User already has an active subscription"}, status_code=400
)
diff --git a/routers/stripe.py b/routers/stripe.py
index 3d481d47b..04538bd55 100644
--- a/routers/stripe.py
+++ b/routers/stripe.py
@@ -46,6 +46,6 @@ def webhook_received(request: Request, payload: bytes = fastapi_request_body):
case "customer.subscription.created" | "customer.subscription.updated":
StripeWebhookHandler.handle_subscription_updated(uid, data)
case "customer.subscription.deleted":
- StripeWebhookHandler.handle_subscription_cancelled(uid, data)
+ StripeWebhookHandler.handle_subscription_cancelled(uid)
return JSONResponse({"status": "success"})
diff --git a/scripts/migrate_existing_subscriptions.py b/scripts/migrate_existing_subscriptions.py
index 967b5bf73..ceb7dc224 100644
--- a/scripts/migrate_existing_subscriptions.py
+++ b/scripts/migrate_existing_subscriptions.py
@@ -29,7 +29,7 @@
# "quantity": 1000, # number of credits (set by html)
"adjustable_quantity": {
"enabled": True,
- "maximum": 50_000,
+ "maximum": 100_000,
"minimum": 1_000,
},
},
diff --git a/tests/test_checkout.py b/tests/test_checkout.py
index 21a9a3291..a392c2cbc 100644
--- a/tests/test_checkout.py
+++ b/tests/test_checkout.py
@@ -3,7 +3,7 @@
from app_users.models import AppUser
from daras_ai_v2 import settings
-from daras_ai_v2.billing import stripe_subscription_checkout_redirect
+from daras_ai_v2.billing import stripe_subscription_create
from gooey_gui import RedirectException
from payments.plans import PricingPlan
from server import app
@@ -20,4 +20,4 @@ def test_create_checkout_session(
return
with pytest.raises(RedirectException):
- stripe_subscription_checkout_redirect(force_authentication, plan)
+ stripe_subscription_create(force_authentication, plan)