diff --git a/app_users/admin.py b/app_users/admin.py index da50b57e9..caa61d223 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -192,19 +192,24 @@ class AppUserTransactionAdmin(admin.ModelAdmin): "invoice_id", "user", "amount", + "dollar_amount", "end_balance", "payment_provider", - "dollar_amount", + "reason", + "plan", "created_at", ] - readonly_fields = ["created_at"] + readonly_fields = ["view_payment_provider_url", "created_at"] list_filter = [ - "created_at", + "reason", ("payment_provider", admin.EmptyFieldListFilter), "payment_provider", + "plan", + "created_at", ] inlines = [SavedRunInline] ordering = ["-created_at"] + search_fields = ["invoice_id"] @admin.display(description="Charged Amount") def dollar_amount(self, obj: models.AppUserTransaction): @@ -212,6 +217,14 @@ def dollar_amount(self, obj: models.AppUserTransaction): return return f"${obj.charged_amount / 100}" + @admin.display(description="Payment Provider URL") + def view_payment_provider_url(self, txn: models.AppUserTransaction): + url = txn.payment_provider_url() + if url: + return open_in_new_tab(url, label=url) + else: + raise txn.DoesNotExist + @admin.register(LogEntry) class LogEntryAdmin(admin.ModelAdmin): diff --git a/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py b/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py new file mode 100644 index 000000000..2891795fd --- /dev/null +++ b/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py @@ -0,0 +1,59 @@ +# Generated by Django 4.2.7 on 2024-07-14 20:51 + +from django.db import migrations, models + + +def forwards_func(apps, schema_editor): + from payments.plans import PricingPlan + from app_users.models import TransactionReason + + # We get the model from the versioned app registry; + # if we directly import it, it'll be the wrong version + AppUserTransaction = apps.get_model("app_users", "AppUserTransaction") + db_alias = schema_editor.connection.alias + objects = AppUserTransaction.objects.using(db_alias) + + for transaction in objects.all(): + if transaction.amount <= 0: + transaction.reason = TransactionReason.DEDUCT + else: + # For old transactions, we didn't have a subscription field. + # It just so happened that all monthly subscriptions we offered had + # different amounts from the one-time purchases. + # This uses that heuristic to determine whether a transaction + # was a subscription payment or a one-time purchase. + transaction.reason = TransactionReason.ADDON + for plan in PricingPlan: + if ( + transaction.amount == plan.credits + and transaction.charged_amount == plan.monthly_charge * 100 + ): + transaction.plan = plan.db_value + transaction.reason = TransactionReason.SUBSCRIBE + transaction.save(update_fields=["reason", "plan"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0017_alter_appuser_subscription'), + ] + + operations = [ + migrations.AddField( + model_name='appusertransaction', + name='plan', + field=models.IntegerField(blank=True, choices=[(1, 'Basic Plan'), (2, 'Premium Plan'), (3, 'Starter'), (4, 'Creator'), (5, 'Business'), (6, 'Enterprise / Agency')], default=None, help_text="User's plan at the time of this transaction.", null=True), + ), + migrations.AddField( + model_name='appusertransaction', + name='reason', + field=models.IntegerField(choices=[(1, 'Deduct'), (2, 'Addon'), (3, 'Subscribe'), (4, 'Sub-Create'), (5, 'Sub-Cycle'), (6, 'Sub-Update'), (7, 'Auto-Recharge')], default=0, help_text='The reason for this transaction.

Deduct: Credits deducted due to a run.
Addon: User purchased an add-on.
Subscribe: Applies to subscriptions where no distinction was made between create, update and cycle.
Sub-Create: A subscription was created.
Sub-Cycle: A subscription advanced into a new period.
Sub-Update: A subscription was updated.
Auto-Recharge: Credits auto-recharged due to low balance.'), + ), + migrations.RunPython(forwards_func, migrations.RunPython.noop), + migrations.AlterField( + model_name='appusertransaction', + name='reason', + field=models.IntegerField(choices=[(1, 'Deduct'), (2, 'Addon'), (3, 'Subscribe'), (4, 'Sub-Create'), (5, 'Sub-Cycle'), (6, 'Sub-Update'), (7, 'Auto-Recharge')], help_text='The reason for this transaction.

Deduct: Credits deducted due to a run.
Addon: User purchased an add-on.
Subscribe: Applies to subscriptions where no distinction was made between create, update and cycle.
Sub-Create: A subscription was created.
Sub-Cycle: A subscription advanced into a new period.
Sub-Update: A subscription was updated.
Auto-Recharge: Credits auto-recharged due to low balance.'), + ), + ] diff --git a/app_users/models.py b/app_users/models.py index f4b29f490..1e1016520 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -4,6 +4,7 @@ from django.db.models import Sum from django.utils import timezone from firebase_admin import auth +from furl import furl from phonenumber_field.modelfields import PhoneNumberField from bots.custom_fields import CustomURLField, StrippedTextField @@ -172,6 +173,7 @@ def add_balance( user: AppUser = AppUser.objects.select_for_update().get(pk=self.pk) user.balance += amount user.save(update_fields=["balance"]) + kwargs.setdefault("plan", user.subscription and user.subscription.plan) return AppUserTransaction.objects.create( user=self, invoice_id=invoice_id, @@ -273,6 +275,18 @@ def get_dollars_spent_this_month(self) -> float: return (cents_spent or 0) / 100 +class TransactionReason(models.IntegerChoices): + DEDUCT = 1, "Deduct" + ADDON = 2, "Addon" + + SUBSCRIBE = 3, "Subscribe" + SUBSCRIPTION_CREATE = 4, "Sub-Create" + SUBSCRIPTION_CYCLE = 5, "Sub-Cycle" + SUBSCRIPTION_UPDATE = 6, "Sub-Update" + + AUTO_RECHARGE = 7, "Auto-Recharge" + + class AppUserTransaction(models.Model): user = models.ForeignKey( "AppUser", on_delete=models.CASCADE, related_name="transactions" @@ -307,6 +321,25 @@ class AppUserTransaction(models.Model): default=0, ) + reason = models.IntegerField( + choices=TransactionReason.choices, + help_text="The reason for this transaction.

" + f"{TransactionReason.DEDUCT.label}: Credits deducted due to a run.
" + f"{TransactionReason.ADDON.label}: User purchased an add-on.
" + f"{TransactionReason.SUBSCRIBE.label}: Applies to subscriptions where no distinction was made between create, update and cycle.
" + f"{TransactionReason.SUBSCRIPTION_CREATE.label}: A subscription was created.
" + f"{TransactionReason.SUBSCRIPTION_CYCLE.label}: A subscription advanced into a new period.
" + f"{TransactionReason.SUBSCRIPTION_UPDATE.label}: A subscription was updated.
" + f"{TransactionReason.AUTO_RECHARGE.label}: Credits auto-recharged due to low balance.", + ) + plan = models.IntegerField( + choices=PricingPlan.db_choices(), + help_text="User's plan at the time of this transaction.", + null=True, + blank=True, + default=None, + ) + created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now) class Meta: @@ -320,32 +353,41 @@ class Meta: def __str__(self): return f"{self.invoice_id} ({self.amount})" - def get_subscription_plan(self) -> PricingPlan | None: - """ - It just so happened that all monthly subscriptions we offered had - different amounts from the one-time purchases. - This uses that heuristic to determine whether a transaction - was a subscription payment or a one-time purchase. - - TODO: Implement this more robustly - """ - if self.amount <= 0: - # credits deducted - return None - - for plan in PricingPlan: - if ( - self.amount == plan.credits - and self.charged_amount == plan.monthly_charge * 100 + def save(self, *args, **kwargs): + if self.reason is None: + if self.amount <= 0: + self.reason = TransactionReason.DEDUCT + else: + self.reason = TransactionReason.ADDON + super().save(*args, **kwargs) + + def reason_note(self) -> str: + match self.reason: + case ( + TransactionReason.SUBSCRIPTION_CREATE + | TransactionReason.SUBSCRIPTION_CYCLE + | TransactionReason.SUBSCRIPTION_UPDATE + | TransactionReason.SUBSCRIBE ): - return plan - - return None - - def note(self) -> str: - if self.amount <= 0: - return "" - elif plan := self.get_subscription_plan(): - return f"Subscription payment: {plan.title} (+{self.amount:,} credits)" - else: - return f"Addon purchase (+{self.amount:,} credits)" + ret = "Subscription payment" + if self.plan: + ret += f": {PricingPlan.from_db_value(self.plan).title}" + return ret + case TransactionReason.AUTO_RECHARGE: + return "Auto recharge" + case TransactionReason.ADDON: + return "Addon purchase" + case TransactionReason.DEDUCT: + return "Run deduction" + + def payment_provider_url(self) -> str | None: + match self.payment_provider: + case PaymentProvider.STRIPE: + return str( + furl("https://dashboard.stripe.com/invoices/") / self.invoice_id + ) + case PaymentProvider.PAYPAL: + return str( + furl("https://www.paypal.com/unifiedtransactions/details/payment/") + / self.invoice_id + ) diff --git a/bots/admin.py b/bots/admin.py index 82da0aab2..e154f03b2 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -31,7 +31,7 @@ Workflow, ) from bots.tasks import create_personal_channels_for_all_members -from celeryapp.tasks import gui_runner +from celeryapp.tasks import runner_task from daras_ai_v2.fastapi_tricks import get_route_url from gooeysite.custom_actions import export_to_excel, export_to_csv from gooeysite.custom_filters import ( diff --git a/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py new file mode 100644 index 000000000..44033b955 --- /dev/null +++ b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 4.2.7 on 2024-07-12 19:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0076_alter_workflowmetadata_default_image_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='savedrun', + name='error_code', + field=models.IntegerField(blank=True, default=None, help_text='The HTTP status code of the error. If this is not set, 500 is assumed.', null=True), + ), + migrations.AddField( + model_name='savedrun', + name='error_type', + field=models.TextField(blank=True, default='', help_text='The exception type'), + ), + migrations.AlterField( + model_name='savedrun', + name='error_msg', + field=models.TextField(blank=True, default='', help_text='The error message. If this is not set, the run is deemed successful.'), + ), + ] diff --git a/bots/models.py b/bots/models.py index 6ff6674ff..757766a4c 100644 --- a/bots/models.py +++ b/bots/models.py @@ -212,10 +212,24 @@ class SavedRun(models.Model): state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder) - error_msg = models.TextField(default="", blank=True) + error_msg = models.TextField( + default="", + blank=True, + help_text="The error message. If this is not set, the run is deemed successful.", + ) run_time = models.DurationField(default=datetime.timedelta, blank=True) run_status = models.TextField(default="", blank=True) + error_code = models.IntegerField( + null=True, + default=None, + blank=True, + help_text="The HTTP status code of the error. If this is not set, 500 is assumed.", + ) + error_type = models.TextField( + default="", blank=True, help_text="The exception type" + ) + hidden = models.BooleanField(default=False) is_flagged = models.BooleanField(default=False) @@ -282,9 +296,12 @@ def __str__(self): def parent_published_run(self) -> typing.Optional["PublishedRun"]: return self.parent_version and self.parent_version.published_run - def get_app_url(self): + def get_app_url(self, query_params: dict = None): return Workflow(self.workflow).page_cls.app_url( - example_id=self.example_id, run_id=self.run_id, uid=self.uid + example_id=self.example_id, + run_id=self.run_id, + uid=self.uid, + query_params=query_params, ) def to_dict(self) -> dict: @@ -1624,9 +1641,9 @@ def duplicate( visibility=visibility, ) - def get_app_url(self): + def get_app_url(self, query_params: dict = None): return Workflow(self.workflow).page_cls.app_url( - example_id=self.published_run_id + example_id=self.published_run_id, query_params=query_params ) def add_version( diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index b2a7b4327..2e30e4379 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -14,32 +14,34 @@ import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.admin_links import change_obj_url -from bots.models import SavedRun, Platform +from bots.models import SavedRun, Platform, Workflow from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings -from daras_ai_v2.auto_recharge import auto_recharge_user from daras_ai_v2.base import StateKeys, BasePage from daras_ai_v2.exceptions import UserError -from daras_ai_v2.redis_cache import redis_lock from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push from gooey_ui.state import set_query_params from gooeysite.bg_db_conn import db_middleware +from payments.auto_recharge import ( + should_attempt_auto_recharge, + run_auto_recharge_gracefully, +) DEFAULT_RUN_STATUS = "Running..." @app.task -def gui_runner( +def runner_task( *, page_cls: typing.Type[BasePage], user_id: int, run_id: str, uid: str, channel: str, -): +) -> int: start_time = time() error_msg = None @@ -89,34 +91,50 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False save_on_step() for val in page.main(sr, st.session_state): save_on_step(val) + # render errors nicely except Exception as e: - if isinstance(e, HTTPException) and e.status_code == 402: - error_msg = page.generate_credit_error_message(run_id, uid) - try: - raise UserError(error_msg) from e - except UserError as e: - sentry_sdk.capture_exception(e, level=e.sentry_level) + if isinstance(e, UserError): + sentry_level = e.sentry_level else: - if isinstance(e, UserError): - sentry_level = e.sentry_level - else: - sentry_level = "error" - traceback.print_exc() - sentry_sdk.capture_exception(e, level=sentry_level) - error_msg = err_msg_for_exc(e) + sentry_level = "error" + traceback.print_exc() + sentry_sdk.capture_exception(e, level=sentry_level) + error_msg = err_msg_for_exc(e) + sr.error_type = type(e).__qualname__ + sr.error_code = getattr(e, "status_code", None) + # run completed successfully, deduct credits else: sr.transaction, sr.price = page.deduct_credits(st.session_state) + + # save everything, mark run as completed finally: save_on_step(done=True) - if not sr.is_api_call: - send_email_on_completion(page, sr) - run_low_balance_email_check(user) + + return sr.id + + +@app.task +def post_runner_tasks(saved_run_id: int): + sr = SavedRun.objects.get(id=saved_run_id) + user = AppUser.objects.get(uid=sr.uid) + + if not sr.is_api_call: + send_email_on_completion(sr) + + if should_attempt_auto_recharge(user): + run_auto_recharge_gracefully(user) + + run_low_balance_email_check(user) def err_msg_for_exc(e: Exception): - if isinstance(e, requests.HTTPError): + if isinstance(e, UserError): + return e.message + elif isinstance(e, HTTPException): + return f"(HTTP {e.status_code}) {e.detail})" + elif isinstance(e, requests.HTTPError): response: requests.Response = e.response try: err_body = response.json() @@ -133,10 +151,6 @@ def err_msg_for_exc(e: Exception): return f"(GPU) {err_type}: {err_str}" err_str = str(err_body) return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}" - elif isinstance(e, HTTPException): - return f"(HTTP {e.status_code}) {e.detail})" - elif isinstance(e, UserError): - return e.message else: return f"{type(e).__name__}: {e}" @@ -177,7 +191,7 @@ def run_low_balance_email_check(user: AppUser): user.save(update_fields=["low_balance_email_sent_at"]) -def send_email_on_completion(page: BasePage, sr: SavedRun): +def send_email_on_completion(sr: SavedRun): run_time_sec = sr.run_time.total_seconds() if ( run_time_sec <= settings.SEND_RUN_EMAIL_AFTER_SEC @@ -189,9 +203,16 @@ def send_email_on_completion(page: BasePage, sr: SavedRun): ) if not to_address: return - prompt = (page.preview_input(sr.state) or "").strip() - title = (sr.state.get("__title") or page.title).strip() - subject = f"🌻 “{truncate_text_words(prompt, maxlen=50)}” {title} is done" + + workflow = Workflow(sr.workflow) + page_cls = workflow.page_cls + prompt = (page_cls.preview_input(sr.state) or "").strip().replace("\n", " ") + recipe_title = page_cls.get_recipe_title() + + subject = ( + f"🌻 “{truncate_text_words(prompt, maxlen=50) or 'Run'}” {recipe_title} is done" + ) + send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=to_address, @@ -200,7 +221,7 @@ def send_email_on_completion(page: BasePage, sr: SavedRun): run_time_sec=round(run_time_sec), app_url=sr.get_app_url(), prompt=prompt, - title=title, + recipe_title=recipe_title, ), message_stream="gooey-ai-workflows", ) @@ -221,11 +242,3 @@ def send_integration_attempt_email(*, user_id: int, platform: Platform, run_url: subject=f"{user.display_name} Attempted to Connect to {platform.label}", html_body=html_body, ) - - -@app.task -def auto_recharge(*, user_id: int): - redis_lock_key = f"gooey/auto_recharge/{user_id}" - with redis_lock(redis_lock_key): - user = AppUser.objects.get(id=user_id) - auto_recharge_user(user) diff --git a/conftest.py b/conftest.py index 96fb837f1..a38c6a11a 100644 --- a/conftest.py +++ b/conftest.py @@ -51,16 +51,17 @@ def force_authentication(): @pytest.fixture -def mock_gui_runner(): +def mock_celery_tasks(): with ( - patch("celeryapp.tasks.gui_runner", _mock_gui_runner), + patch("celeryapp.tasks.runner_task", _mock_runner_task), + patch("celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks), patch("daras_ai_v2.bots.realtime_subscribe", _mock_realtime_subscribe), ): yield @app.task -def _mock_gui_runner( +def _mock_runner_task( *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs ): sr = page_cls.run_doc_sr(run_id, uid) @@ -70,6 +71,11 @@ def _mock_gui_runner( _mock_realtime_push(channel, sr.to_dict()) +@app.task +def _mock_post_runner_tasks(*args, **kwargs): + pass + + def _mock_realtime_push(channel, value): redis_qs[channel].put(value) diff --git a/daras_ai_v2/auto_recharge.py b/daras_ai_v2/auto_recharge.py deleted file mode 100644 index 15c47e42f..000000000 --- a/daras_ai_v2/auto_recharge.py +++ /dev/null @@ -1,64 +0,0 @@ -from loguru import logger - -from app_users.models import AppUser, PaymentProvider -from payments.tasks import send_email_budget_reached, send_email_auto_recharge_failed - - -def auto_recharge_user(user: AppUser): - if not user_should_auto_recharge(user): - logger.info(f"User doesn't need to auto-recharge: {user=}") - return - - dollars_spent = user.get_dollars_spent_this_month() - if ( - dollars_spent + user.subscription.auto_recharge_topup_amount - > user.subscription.monthly_spending_budget - ): - if not user.subscription.has_sent_monthly_budget_email_this_month(): - send_email_budget_reached.delay(user.id) - logger.info(f"User has reached the monthly budget: {user=}, {dollars_spent=}") - return - - match user.subscription.payment_provider: - case PaymentProvider.STRIPE: - customer = user.search_stripe_customer() - if not customer: - logger.error(f"User doesn't have a stripe customer: {user=}") - return - - try: - invoice = user.subscription.stripe_get_or_create_auto_invoice( - amount_in_dollars=user.subscription.auto_recharge_topup_amount, - metadata_key="auto_recharge", - ) - - if invoice.status == "open": - pm = user.subscription.stripe_get_default_payment_method() - invoice.pay(payment_method=pm) - logger.info( - f"Payment attempted for auto recharge invoice: {user=}, {invoice=}" - ) - elif invoice.status == "paid": - logger.info( - f"Auto recharge invoice already paid recently: {user=}, {invoice=}" - ) - except Exception as e: - logger.error( - f"Error while auto-recharging user: {user=}, {e=}, {invoice=}" - ) - send_email_auto_recharge_failed.delay(user.id) - - case PaymentProvider.PAYPAL: - logger.error(f"Auto-recharge not supported for PayPal: {user=}") - return - - -def user_should_auto_recharge(user: AppUser): - """ - whether an auto recharge should be attempted for the user - """ - return ( - user.subscription - and user.subscription.auto_recharge_enabled - and user.balance < user.subscription.auto_recharge_balance_threshold - ) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index cfb731d4c..a447b4252 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -38,7 +38,6 @@ from daras_ai.text_format import format_number_with_suffix from daras_ai_v2 import settings, urls from daras_ai_v2.api_examples_widget import api_example_generator -from daras_ai_v2.auto_recharge import user_should_auto_recharge from daras_ai_v2.breadcrumbs import render_breadcrumbs, get_title_breadcrumbs from daras_ai_v2.copy_to_clipboard_button_widget import ( copy_to_clipboard_button, @@ -49,6 +48,7 @@ from daras_ai_v2.db import ( ANONYMOUS_USER_COOKIE, ) +from daras_ai_v2.exceptions import InsufficientCredits from daras_ai_v2.fastapi_tricks import get_route_path from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.html_spinner_widget import html_spinner @@ -83,6 +83,10 @@ from gooey_ui.components.modal import Modal from gooey_ui.components.pills import pill from gooey_ui.pubsub import realtime_pull +from payments.auto_recharge import ( + should_attempt_auto_recharge, + run_auto_recharge_gracefully, +) from routers.account import AccountTabs from routers.root import RecipeTabs @@ -141,7 +145,6 @@ class BasePage: class RequestModel(BaseModel): functions: list[RecipeFunction] | None = Field( - None, title="🧩 Functions", ) variables: dict[str, typing.Any] = Field( @@ -1411,6 +1414,8 @@ def render_usage_guide(self): raise NotImplementedError def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]: + yield from self.ensure_credits_and_auto_recharge(sr, state) + yield from call_recipe_functions( saved_run=sr, current_user=self.request.user, @@ -1437,15 +1442,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: response = self.ResponseModel.construct() # run the recipe - gen = self.run_v2(request, response) - while True: - try: - val = next(gen) - except StopIteration: - break - finally: + try: + for val in self.run_v2(request, response): state.update(response.dict(exclude_unset=True)) - yield val + yield val + finally: + state.update(response.dict(exclude_unset=True)) # validate the response if successful self.ResponseModel.validate(response) @@ -1595,8 +1597,6 @@ def estimate_run_duration(self) -> int | None: pass def on_submit(self): - from celeryapp.tasks import auto_recharge - try: sr = self.create_new_run(enable_rate_limits=True) except ValidationError as e: @@ -1608,15 +1608,7 @@ def on_submit(self): st.session_state[StateKeys.error_msg] = e.detail.get("error", "") return - if user_should_auto_recharge(self.request.user): - auto_recharge.delay(user_id=self.request.user.id) - - if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits(): - sr.run_status = "" - sr.error_msg = self.generate_credit_error_message(sr.run_id, sr.uid) - sr.save(update_fields=["run_status", "error_msg"]) - else: - self.call_runner_task(sr) + self.call_runner_task(sr) raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) @@ -1689,15 +1681,19 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun): ) def call_runner_task(self, sr: SavedRun): - from celeryapp.tasks import gui_runner - - return gui_runner.delay( - page_cls=self.__class__, - user_id=self.request.user.id, - run_id=sr.run_id, - uid=sr.uid, - channel=self.realtime_channel_name(sr.run_id, sr.uid), + from celeryapp.tasks import runner_task, post_runner_tasks + + chain = ( + runner_task.s( + page_cls=self.__class__, + user_id=self.request.user.id, + run_id=sr.run_id, + uid=sr.uid, + channel=self.realtime_channel_name(sr.run_id, sr.uid), + ) + | post_runner_tasks.s() ) + return chain.apply_async() @classmethod def realtime_channel_name(cls, run_id, uid): @@ -2073,10 +2069,27 @@ def run_as_api_tab(self): manage_api_keys(self.request.user) - def check_credits(self) -> bool: + def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict): + if not settings.CREDITS_TO_DEDUCT_PER_RUN: + return assert self.request, "request must be set to check credits" assert self.request.user, "request.user must be set to check credits" - return self.request.user.balance >= self.get_price_roundoff(st.session_state) + + user = self.request.user + price = self.get_price_roundoff(state) + + if user.balance >= price: + return + + if should_attempt_auto_recharge(user): + yield "Low balance detected. Recharging..." + run_auto_recharge_gracefully(user) + user.refresh_from_db() + + if user.balance >= price: + return + + raise InsufficientCredits(self.request.user, sr) def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: assert ( diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 1e9b2f764..709fde926 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -1,3 +1,5 @@ +from typing import Literal + import stripe from django.core.exceptions import ValidationError @@ -13,10 +15,14 @@ from gooey_ui.components.pills import pill from payments.models import PaymentMethodSummary from payments.plans import PricingPlan +from scripts.migrate_existing_subscriptions import available_subscriptions rounded_border = "w-100 border shadow-sm rounded py-4 px-3" +PlanActionLabel = Literal["Upgrade", "Downgrade", "Contact Us", "Your Plan"] + + def billing_page(user: AppUser): render_payments_setup() @@ -27,14 +33,15 @@ def billing_page(user: AppUser): render_credit_balance(user) with st.div(className="my-5"): - render_all_plans(user) + selected_payment_provider = render_all_plans(user) + + with st.div(className="my-5"): + render_addon_section(user, selected_payment_provider) if user.subscription and user.subscription.payment_provider: if user.subscription.payment_provider == PaymentProvider.STRIPE: with st.div(className="my-5"): render_auto_recharge_section(user) - with st.div(className="my-5"): - render_addon_section(user) with st.div(className="my-5"): render_payment_information(user) @@ -115,7 +122,7 @@ def render_credit_balance(user: AppUser): ) -def render_all_plans(user: AppUser): +def render_all_plans(user: AppUser) -> PaymentProvider: current_plan = ( PricingPlan.from_sub(user.subscription) if user.subscription @@ -126,11 +133,11 @@ def render_all_plans(user: AppUser): st.write("## All Plans") plans_div = st.div(className="mb-1") - if user.subscription: - payment_provider = None + if user.subscription and user.subscription.payment_provider: + selected_payment_provider = None else: with st.div(): - payment_provider = PaymentProvider[ + selected_payment_provider = PaymentProvider[ payment_provider_radio() or PaymentProvider.STRIPE.name ] @@ -144,7 +151,9 @@ def _render_plan(plan: PricingPlan): className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}" ): _render_plan_details(plan) - _render_plan_action_button(user, plan, current_plan, payment_provider) + _render_plan_action_button( + user, plan, current_plan, selected_payment_provider + ) with plans_div: grid_layout(4, all_plans, _render_plan, separator=False) @@ -152,6 +161,8 @@ def _render_plan(plan: PricingPlan): with st.div(className="my-2 d-flex justify-content-center"): st.caption(f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**") + return selected_payment_provider + def _render_plan_details(plan: PricingPlan): with st.div(className="flex-grow-1"): @@ -188,33 +199,47 @@ def _render_plan_action_button( className=btn_classes + " btn btn-theme btn-primary", ): st.html("Contact Us") - elif current_plan is not PricingPlan.ENTERPRISE: - update_subscription_button( - user=user, - plan=plan, - current_plan=current_plan, - className=btn_classes, - payment_provider=payment_provider, - ) + elif user.subscription and not user.subscription.payment_provider: + # don't show upgrade/downgrade buttons for enterprise customers + # assumption: anyone without a payment provider attached is admin/enterprise + return + else: + if plan.credits > current_plan.credits: + label, btn_type = ("Upgrade", "primary") + else: + label, btn_type = ("Downgrade", "secondary") + + if user.subscription and user.subscription.payment_provider: + # subscription exists, show upgrade/downgrade button + _render_update_subscription_button( + label, + user=user, + current_plan=current_plan, + plan=plan, + className=f"{btn_classes} btn btn-theme btn-{btn_type}", + ) + else: + assert payment_provider is not None # for sanity + _render_create_subscription_button( + label, + btn_type=btn_type, + user=user, + plan=plan, + payment_provider=payment_provider, + ) -def update_subscription_button( +def _render_create_subscription_button( + label: PlanActionLabel, *, + btn_type: str, user: AppUser, - current_plan: PricingPlan, plan: PricingPlan, - className: str = "", - payment_provider: PaymentProvider | None = None, + payment_provider: PaymentProvider, ): - if plan.credits > current_plan.credits: - label, btn_type = ("Upgrade", "primary") - else: - label, btn_type = ("Downgrade", "secondary") - className += f" btn btn-theme btn-{btn_type}" - - key = f"change-sub-{plan.key}" match payment_provider: case PaymentProvider.STRIPE: + key = f"stripe-sub-{plan.key}" render_stripe_subscription_button( user=user, label=label, @@ -224,7 +249,19 @@ def update_subscription_button( ) case PaymentProvider.PAYPAL: render_paypal_subscription_button(plan=plan) - case _ if label == "Downgrade": + + +def _render_update_subscription_button( + label: PlanActionLabel, + *, + user: AppUser, + current_plan: PricingPlan, + plan: PricingPlan, + className: str = "", +): + key = f"change-sub-{plan.key}" + match label: + case "Downgrade": downgrade_modal = Modal( "Confirm downgrade", key=f"downgrade-plan-modal-{plan.key}", @@ -260,7 +297,12 @@ def update_subscription_button( downgrade_modal.close() case _: if st.button(label, className=className, key=key): - change_subscription(user, plan) + change_subscription( + user, + plan, + # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time + billing_cycle_anchor="now", + ) def fmt_price(plan: PricingPlan) -> str: @@ -270,7 +312,7 @@ def fmt_price(plan: PricingPlan) -> str: return "Free" -def change_subscription(user: AppUser, new_plan: PricingPlan): +def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): from routers.account import account_route from routers.account import payment_processing_route @@ -301,9 +343,7 @@ def change_subscription(user: AppUser, new_plan: PricingPlan): metadata={ settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: new_plan.key, }, - # charge the full new amount today, without prorations - # see: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time - billing_cycle_anchor="now", + **kwargs, proration_behavior="none", ) raise RedirectException( @@ -337,44 +377,57 @@ def payment_provider_radio(**props) -> str | None: ) -def render_addon_section(user: AppUser): - assert user.subscription - - st.write("# Purchase More Credits") +def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider): + if user.subscription: + st.write("# Purchase More Credits") + else: + st.write("# Purchase Credits") st.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") - provider = PaymentProvider(user.subscription.payment_provider) + if user.subscription: + provider = PaymentProvider(user.subscription.payment_provider) + else: + provider = selected_payment_provider match provider: case PaymentProvider.STRIPE: - for amount in settings.ADDON_AMOUNT_CHOICES: - render_stripe_addon_button(amount, user=user) + render_stripe_addon_buttons(user) case PaymentProvider.PAYPAL: - for amount in settings.ADDON_AMOUNT_CHOICES: - render_paypal_addon_button(amount) - st.div( - id="paypal-addon-buttons", - className="mt-2", - style={"width": "fit-content"}, - ) - st.div(id="paypal-result-message") + render_paypal_addon_buttons() -def render_paypal_addon_button(amount: int): - st.html( - f""" - - """ +def render_paypal_addon_buttons(): + selected_amt = st.horizontal_radio( + "", + settings.ADDON_AMOUNT_CHOICES, + format_func=lambda amt: f"${amt:,}", + checked_by_default=False, ) + if selected_amt: + st.js( + f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})" + ) + st.div( + id="paypal-addon-buttons", + className="mt-2", + style={"width": "fit-content"}, + ) + st.div(id="paypal-result-message") + +def render_stripe_addon_buttons(user: AppUser): + for dollat_amt in settings.ADDON_AMOUNT_CHOICES: + render_stripe_addon_button(dollat_amt, user) -def render_stripe_addon_button(amount: int, user: AppUser): - confirm_purchase_modal = Modal("Confirm Purchase", key=f"confirm-purchase-{amount}") - if st.button(f"${amount:,}", type="primary"): - confirm_purchase_modal.open() + +def render_stripe_addon_button(dollat_amt: int, user: AppUser): + confirm_purchase_modal = Modal( + "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}" + ) + if st.button(f"${dollat_amt:,}", type="primary"): + if user.subscription: + confirm_purchase_modal.open() + else: + stripe_addon_checkout_redirect(user, dollat_amt) if not confirm_purchase_modal.is_open(): return @@ -382,7 +435,7 @@ def render_stripe_addon_button(amount: int, user: AppUser): st.write( f""" Please confirm your purchase: - **{amount * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${amount}**. + **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**. """, className="py-4 d-block text-center", ) @@ -390,7 +443,7 @@ def render_stripe_addon_button(amount: int, user: AppUser): if st.session_state.get("--confirm-purchase"): success = st.run_in_thread( user.subscription.stripe_attempt_addon_purchase, - args=[amount], + args=[dollat_amt], placeholder="Processing payment...", ) if success is None: @@ -407,6 +460,27 @@ def render_stripe_addon_button(amount: int, user: AppUser): st.button("Buy", type="primary", key="--confirm-purchase") +def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): + from routers.account import account_route + from routers.account import payment_processing_route + + line_item = available_subscriptions["addon"]["stripe"].copy() + line_item["quantity"] = dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR + checkout_session = stripe.checkout.Session.create( + line_items=[line_item], + mode="payment", + success_url=get_route_url(payment_processing_route), + cancel_url=get_route_url(account_route), + customer=user.get_or_create_stripe_customer(), + invoice_creation={"enabled": True}, + allow_promotion_codes=True, + saved_payment_method_options={ + "payment_method_save": "enabled", + }, + ) + raise RedirectException(checkout_session.url, status_code=303) + + def render_stripe_subscription_button( *, label: str, @@ -419,16 +493,19 @@ def render_stripe_subscription_button( st.write("Stripe subscription not available") return - if st.button(label, type=btn_type, key=key): - create_stripe_checkout_session(user=user, plan=plan) + # IMPORTANT: key=... is needed here to maintain uniqueness + # of buttons with the same label. otherwise, all buttons + # will be the same to the server + if st.button(label, key=key, type=btn_type): + stripe_subscription_checkout_redirect(user=user, plan=plan) -def create_stripe_checkout_session(user: AppUser, plan: PricingPlan): +def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan): from routers.account import account_route from routers.account import payment_processing_route - if user.subscription and user.subscription.plan == plan.db_value: - # already subscribed to the same plan + if user.subscription: + # already subscribed to some plan return metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key} @@ -543,16 +620,17 @@ def render_billing_history(user: AppUser, limit: int = 50): st.write("## Billing History", className="d-block") st.table( pd.DataFrame.from_records( - columns=[""] * 3, - data=[ - [ - txn.created_at.strftime("%m/%d/%Y"), - txn.note(), - f"${txn.charged_amount / 100:,.2f}", - ] + [ + { + "Date": txn.created_at.strftime("%m/%d/%Y"), + "Description": txn.reason_note(), + "Amount": f"-${txn.charged_amount / 100:,.2f}", + "Credits": f"+{txn.amount:,}", + "Balance": f"{txn.end_balance:,}", + } for txn in txns[:limit] - ], - ) + ] + ), ) if txns.count() > limit: st.caption(f"Showing only the most recent {limit} transactions.") diff --git a/daras_ai_v2/exceptions.py b/daras_ai_v2/exceptions.py index 78e3f053f..02acc4ec6 100644 --- a/daras_ai_v2/exceptions.py +++ b/daras_ai_v2/exceptions.py @@ -3,8 +3,16 @@ import typing import requests +from furl import furl from loguru import logger from requests import HTTPError +from starlette.status import HTTP_402_PAYMENT_REQUIRED + +from daras_ai_v2 import settings + +if typing.TYPE_CHECKING: + from bots.models import SavedRun + from bots.models import AppUser def raise_for_status(resp: requests.Response, is_user_url: bool = False): @@ -47,9 +55,12 @@ def _response_preview(resp: requests.Response) -> bytes: class UserError(Exception): - def __init__(self, message: str, sentry_level: str = "info"): + def __init__( + self, message: str, sentry_level: str = "info", status_code: int = None + ): self.message = message self.sentry_level = sentry_level + self.status_code = status_code super().__init__(message) @@ -57,6 +68,41 @@ class GPUError(UserError): pass +class InsufficientCredits(UserError): + def __init__(self, user: "AppUser", sr: "SavedRun"): + from daras_ai_v2.base import SUBMIT_AFTER_LOGIN_Q + + account_url = furl(settings.APP_BASE_URL) / "account/" + if user.is_anonymous: + account_url.query.params["next"] = sr.get_app_url( + query_params={SUBMIT_AFTER_LOGIN_Q: "1"}, + ) + # language=HTML + message = f""" +

+Doh! Please login to run more Gooey.AI workflows. +

+ +You’ll receive {settings.LOGIN_USER_FREE_CREDITS} Credits when you sign up via your phone #, Google, Apple or GitHub account +and can purchase more for $1/100 Credits. +""" + else: + # language=HTML + message = f""" +

+Doh! You’re out of Gooey.AI credits. +

+ +

+Please buy more to run more workflows. +

+ +We’re always on discord if you’ve got any questions. +""" + + super().__init__(message, status_code=HTTP_402_PAYMENT_REQUIRED) + + FFMPEG_ERR_MSG = ( "Unsupported File Format\n\n" "We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. " diff --git a/daras_ai_v2/paypal.py b/daras_ai_v2/paypal.py index 4c5229b98..eae5042d4 100644 --- a/daras_ai_v2/paypal.py +++ b/daras_ai_v2/paypal.py @@ -155,6 +155,8 @@ class Subscription(PaypalResource): billing_info: BillingInfo | None def cancel(self, *, reason: str = "cancellation_requested") -> None: + if self.status in ["CANCELLED", "EXPIRED"]: + return r = requests.post( str(self.get_resource_url() / "cancel"), headers=get_default_headers(), diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 58878a9af..9b48a1c2e 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -263,7 +263,7 @@ ADMIN_EMAILS = config("ADMIN_EMAILS", cast=Csv(), default="") SUPPORT_EMAIL = "Gooey.AI Support " SALES_EMAIL = "Gooey.AI Sales " -SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 60) +SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 5) DISALLOWED_TITLE_SLUGS = config("DISALLOWED_TITLE_SLUGS", cast=Csv(), default="") + [ # tab names @@ -295,9 +295,10 @@ LOW_BALANCE_EMAIL_ENABLED = config("LOW_BALANCE_EMAIL_ENABLED", True, cast=bool) STRIPE_SECRET_KEY = config("STRIPE_SECRET_KEY", None) +STRIPE_ENDPOINT_SECRET = config("STRIPE_ENDPOINT_SECRET", None) stripe.api_key = STRIPE_SECRET_KEY + STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: str = "subscription_key" -STRIPE_ENDPOINT_SECRET = config("STRIPE_ENDPOINT_SECRET", None) STRIPE_ADDON_PRODUCT_NAME = config( "STRIPE_ADDON_PRODUCT_NAME", "Gooey.AI Add-on Credits" ) diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py new file mode 100644 index 000000000..1aa414cfb --- /dev/null +++ b/payments/auto_recharge.py @@ -0,0 +1,119 @@ +import traceback + +import sentry_sdk +from loguru import logger + +from app_users.models import AppUser, PaymentProvider +from daras_ai_v2.redis_cache import redis_lock +from payments.tasks import ( + send_monthly_budget_reached_email, + send_auto_recharge_failed_email, +) + + +class AutoRechargeException(Exception): + pass + + +class MonthlyBudgetReachedException(AutoRechargeException): + def __init__(self, *args, budget: int, spending: float, **kwargs): + super().__init__(*args, **kwargs) + self.budget = budget + self.spending = spending + + +class PaymentFailedException(AutoRechargeException): + pass + + +class AutoRechargeCooldownException(AutoRechargeException): + pass + + +def should_attempt_auto_recharge(user: AppUser): + return ( + user.subscription + and user.subscription.auto_recharge_enabled + and user.subscription.payment_provider + and user.balance < user.subscription.auto_recharge_balance_threshold + ) + + +def run_auto_recharge_gracefully(user: AppUser): + """ + Wrapper over _auto_recharge_user, that handles exceptions so that it can: + - log exceptions + - send emails when auto-recharge fails + - not retry if this is run as a background task + + Meant to be used in conjunction with should_attempt_auto_recharge + """ + try: + with redis_lock(f"gooey/auto_recharge_user/v1/{user.uid}"): + _auto_recharge_user(user) + except AutoRechargeCooldownException as e: + logger.info( + f"Rejected auto-recharge because auto-recharge is in cooldown period for user" + f"{user=}, {e=}" + ) + except MonthlyBudgetReachedException as e: + send_monthly_budget_reached_email(user) + logger.info( + f"Rejected auto-recharge because user has reached monthly budget" + f"{user=}, spending=${e.spending}, budget=${e.budget}" + ) + except Exception as e: + traceback.print_exc() + sentry_sdk.capture_exception(e) + send_auto_recharge_failed_email(user) + + +def _auto_recharge_user(user: AppUser): + """ + Returns whether a charge was attempted + """ + from payments.webhooks import StripeWebhookHandler + + assert ( + user.subscription.payment_provider == PaymentProvider.STRIPE + ), "Auto recharge is only supported with Stripe" + + # check for monthly budget + dollars_spent = user.get_dollars_spent_this_month() + if ( + dollars_spent + user.subscription.auto_recharge_topup_amount + > user.subscription.monthly_spending_budget + ): + raise MonthlyBudgetReachedException( + "Performing this top-up would exceed your monthly recharge budget", + budget=user.subscription.monthly_spending_budget, + spending=dollars_spent, + ) + + try: + invoice = user.subscription.stripe_get_or_create_auto_invoice( + amount_in_dollars=user.subscription.auto_recharge_topup_amount, + metadata_key="auto_recharge", + ) + except Exception as e: + raise PaymentFailedException("Failed to create auto-recharge invoice") from e + + # recent invoice was already paid + if invoice.status == "paid": + raise AutoRechargeCooldownException( + "An auto recharge invoice was paid recently" + ) + + # get default payment method and attempt payment + assert invoice.status == "open" # sanity check + pm = user.subscription.stripe_get_default_payment_method() + + try: + invoice_data = invoice.pay(payment_method=pm) + except Exception as e: + raise PaymentFailedException( + "Payment failed when attempting to auto-recharge" + ) from e + else: + assert invoice_data.paid + StripeWebhookHandler.handle_invoice_paid(uid=user.uid, invoice=invoice_data) diff --git a/payments/migrations/0004_alter_subscription_auto_recharge_balance_threshold.py b/payments/migrations/0004_alter_subscription_auto_recharge_balance_threshold.py new file mode 100644 index 000000000..d5a1e922a --- /dev/null +++ b/payments/migrations/0004_alter_subscription_auto_recharge_balance_threshold.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-06-10 09:21 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0003_alter_subscription_external_id_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='subscription', + name='auto_recharge_balance_threshold', + field=models.IntegerField(blank=True, null=True), + ), + ] diff --git a/payments/migrations/0005_alter_subscription_plan.py b/payments/migrations/0005_alter_subscription_plan.py new file mode 100644 index 000000000..bfb74dfe3 --- /dev/null +++ b/payments/migrations/0005_alter_subscription_plan.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-07-14 08:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0004_alter_subscription_auto_recharge_balance_threshold'), + ] + + operations = [ + migrations.AlterField( + model_name='subscription', + name='plan', + field=models.IntegerField(choices=[(1, 'Basic Plan'), (2, 'Premium Plan'), (3, 'Starter'), (4, 'Creator'), (5, 'Business'), (6, 'Enterprise / Agency')]), + ), + ] diff --git a/payments/models.py b/payments/models.py index 79ef55ffa..3ce730793 100644 --- a/payments/models.py +++ b/payments/models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time import typing import stripe @@ -43,19 +44,23 @@ class Subscription(models.Model): blank=True, ) auto_recharge_enabled = models.BooleanField(default=True) - auto_recharge_balance_threshold = models.IntegerField() auto_recharge_topup_amount = models.IntegerField( - default=settings.ADDON_AMOUNT_CHOICES[0], + default=settings.ADDON_AMOUNT_CHOICES[0] ) + auto_recharge_balance_threshold = models.IntegerField( + null=True, blank=True + ) # dynamic default (see: full_clean) monthly_spending_budget = models.IntegerField( null=True, blank=True, help_text="In USD, pause auto-recharge just before the spending exceeds this amount in a calendar month", + # dynamic default (see: full_clean) ) monthly_spending_notification_threshold = models.IntegerField( null=True, blank=True, help_text="In USD, send an email when spending crosses this threshold in a calendar month", + # dynamic default (see: full_clean) ) monthly_spending_notification_sent_at = models.DateTimeField(null=True, blank=True) @@ -140,7 +145,10 @@ def get_next_invoice_timestamp(self) -> float | None: return period_end elif self.payment_provider == PaymentProvider.PAYPAL: subscription = paypal.Subscription.retrieve(self.external_id) - if not subscription.billing_info: + if ( + not subscription.billing_info + or not subscription.billing_info.next_billing_time + ): return None return subscription.billing_info.next_billing_time.timestamp() else: @@ -196,9 +204,9 @@ def stripe_get_or_create_auto_invoice( Fetches the relevant invoice, or creates one if it doesn't exist. This is the fallback order: - - Fetch an open invoice with metadata_key in the metadata - - Fetch a $metadata_key invoice that was recently paid - - Create an invoice with amount=amount_in_dollars and $metadata_key + - Fetch an open invoice that has `metadata_key` set + - Fetch a `metadata_key` invoice that was recently paid + - Create an invoice with amount=`amount_in_dollars` and `metadata_key` set to true """ customer_id = self.stripe_get_customer_id() invoices = stripe.Invoice.list( @@ -214,7 +222,7 @@ def stripe_get_or_create_auto_invoice( for inv in invoices: if ( inv.status == "paid" - and timezone.now().timestamp() - inv.created + and abs(time.time() - inv.created) < settings.AUTO_RECHARGE_COOLDOWN_SECONDS ): return inv @@ -256,7 +264,7 @@ def stripe_get_customer_id(self) -> str: raise ValueError("Invalid Payment Provider") def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool: - from routers.stripe import handle_invoice_paid + from payments.webhooks import StripeWebhookHandler invoice = self.stripe_create_auto_invoice( amount_in_dollars=amount_in_dollars, @@ -268,7 +276,7 @@ def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool: invoice = invoice.pay(payment_method=pm) if not invoice.paid: return False - handle_invoice_paid(self.user.uid, invoice) + StripeWebhookHandler.handle_invoice_paid(self.user.uid, invoice) return True def get_external_management_url(self) -> str: diff --git a/payments/plans.py b/payments/plans.py index b2dd6fcdb..badfd2c1c 100644 --- a/payments/plans.py +++ b/payments/plans.py @@ -211,7 +211,7 @@ def __lt__(self, other: PricingPlan) -> bool: @classmethod def db_choices(cls): - return [(plan.db_value, plan.name) for plan in cls] + return [(plan.db_value, plan.title) for plan in cls] @classmethod def from_sub(cls, sub: "Subscription") -> PricingPlan: @@ -240,7 +240,7 @@ def get_by_paypal_plan_id(cls, plan_id: str) -> PricingPlan | None: return plan @classmethod - def get_by_key(cls, key: str): + def get_by_key(cls, key: str) -> PricingPlan | None: for plan in cls: if plan.key == key: return plan diff --git a/payments/tasks.py b/payments/tasks.py index d5177c8e3..6c8b046d5 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -10,70 +10,72 @@ @app.task -def send_email_budget_reached(user_id: int): +def send_monthly_spending_notification_email(user_id: int): from routers.account import account_route user = AppUser.objects.get(id=user_id) if not user.email: + logger.error(f"User doesn't have an email: {user=}") return - email_body = templates.get_template("monthly_budget_reached_email.html").render( - user=user, - account_url=get_route_url(account_route), - ) + threshold = user.subscription.monthly_spending_notification_threshold + send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=user.email, - subject="[Gooey.AI] Monthly Budget Reached", - html_body=email_body, + subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}", + html_body=templates.get_template( + "monthly_spending_notification_threshold_email.html" + ).render( + user=user, + account_url=get_route_url(account_route), + ), ) - user.subscription.monthly_budget_email_sent_at = timezone.now() - user.subscription.save(update_fields=["monthly_budget_email_sent_at"]) + # IMPORTANT: always use update_fields=... / select_for_update when updating + # subscription info. We don't want to overwrite other changes made to + # subscription during the same time + user.subscription.monthly_spending_notification_sent_at = timezone.now() + user.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) -@app.task -def send_email_auto_recharge_failed(user_id: int): +def send_monthly_budget_reached_email(user: AppUser): from routers.account import account_route - user = AppUser.objects.get(id=user_id) if not user.email: return - email_body = templates.get_template("auto_recharge_failed_email.html").render( + email_body = templates.get_template("monthly_budget_reached_email.html").render( user=user, account_url=get_route_url(account_route), ) send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=user.email, - subject="[Gooey.AI] Auto-Recharge failed", + subject="[Gooey.AI] Monthly Budget Reached", html_body=email_body, ) + # IMPORTANT: always use update_fields=... when updating subscription + # info. We don't want to overwrite other changes made to subscription + # during the same time + user.subscription.monthly_budget_email_sent_at = timezone.now() + user.subscription.save(update_fields=["monthly_budget_email_sent_at"]) -@app.task -def send_monthly_spending_notification_email(user_id: int): + +def send_auto_recharge_failed_email(user: AppUser): from routers.account import account_route - user = AppUser.objects.get(id=user_id) if not user.email: - logger.error(f"User doesn't have an email: {user=}") return - threshold = user.subscription.monthly_spending_notification_threshold - + email_body = templates.get_template("auto_recharge_failed_email.html").render( + user=user, + account_url=get_route_url(account_route), + ) send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=user.email, - subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}", - html_body=templates.get_template( - "monthly_spending_notification_threshold_email.html" - ).render( - user=user, - account_url=get_route_url(account_route), - ), + subject="[Gooey.AI] Auto-Recharge failed", + html_body=email_body, ) - - user.subscription.monthly_spending_notification_sent_at = timezone.now() - user.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) diff --git a/payments/webhooks.py b/payments/webhooks.py new file mode 100644 index 000000000..b30cae120 --- /dev/null +++ b/payments/webhooks.py @@ -0,0 +1,233 @@ +import stripe +from django.db import transaction +from loguru import logger + +from app_users.models import AppUser, PaymentProvider, TransactionReason +from daras_ai_v2 import paypal +from .models import Subscription +from .plans import PricingPlan +from .tasks import send_monthly_spending_notification_email + + +class PaypalWebhookHandler: + PROVIDER = PaymentProvider.PAYPAL + + @classmethod + def handle_sale_completed(cls, sale: paypal.Sale): + if not sale.billing_agreement_id: + logger.info(f"sale {sale} is not a subscription sale... skipping") + return + + pp_sub = paypal.Subscription.retrieve(sale.billing_agreement_id) + assert pp_sub.custom_id, "pp_sub is missing uid" + assert pp_sub.plan_id, "pp_sub is missing plan ID" + + plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) + assert plan, f"Plan {pp_sub.plan_id} not found" + + charged_dollars = int(float(sale.amount.total)) # convert to dollars + if charged_dollars != plan.monthly_charge: + # log so that we can investigate, and record the payment as usual + logger.critical( + f"paypal: charged amount ${charged_dollars} does not match plan's monthly charge ${plan.monthly_charge}" + ) + + uid = pp_sub.custom_id + add_balance_for_payment( + uid=uid, + amount=plan.credits, + invoice_id=sale.id, + payment_provider=cls.PROVIDER, + charged_amount=charged_dollars * 100, + reason=TransactionReason.SUBSCRIBE, + plan=plan.db_value, + ) + + @classmethod + def handle_subscription_updated(cls, pp_sub: paypal.Subscription): + logger.info(f"Paypal subscription updated {pp_sub.id}") + + assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" + assert pp_sub.plan_id, f"PayPal subscription {pp_sub.id} is missing plan ID" + + plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) + assert plan, f"Plan with id={pp_sub.plan_id} not found" + + if pp_sub.status.lower() != "active": + logger.info( + "Subscription is not active. Ignoring event", subscription=pp_sub + ) + return + + _set_user_subscription( + provider=cls.PROVIDER, + plan=plan, + uid=pp_sub.custom_id, + external_id=pp_sub.id, + ) + + @classmethod + def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): + assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" + _remove_subscription_for_user( + provider=cls.PROVIDER, uid=pp_sub.custom_id, external_id=pp_sub.id + ) + + +class StripeWebhookHandler: + PROVIDER = PaymentProvider.STRIPE + + @classmethod + def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): + kwargs = {} + if invoice.subscription: + kwargs["plan"] = PricingPlan.get_by_key( + invoice.subscription_details.metadata.get("subscription_key") + ).db_value + match invoice.billing_reason: + case "subscription_create": + reason = TransactionReason.SUBSCRIPTION_CREATE + case "subscription_cycle": + reason = TransactionReason.SUBSCRIPTION_CYCLE + case "subscription_update": + reason = TransactionReason.SUBSCRIPTION_UPDATE + case _: + reason = TransactionReason.SUBSCRIBE + elif invoice.metadata and invoice.metadata.get("auto_recharge"): + reason = TransactionReason.AUTO_RECHARGE + else: + reason = TransactionReason.ADDON + add_balance_for_payment( + uid=uid, + amount=invoice.lines.data[0].quantity, + invoice_id=invoice.id, + payment_provider=cls.PROVIDER, + charged_amount=invoice.lines.data[0].amount, + reason=reason, + **kwargs, + ) + + @classmethod + def handle_checkout_session_completed(cls, uid: str, session_data): + if setup_intent_id := session_data.get("setup_intent") is None: + # not a setup mode checkout -- do nothing + return + setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) + + # subscription_id was passed to metadata when creating the session + sub_id = setup_intent.metadata["subscription_id"] + assert ( + sub_id + ), f"subscription_id is missing in setup_intent metadata {setup_intent}" + + stripe.Subscription.modify( + sub_id, default_payment_method=setup_intent.payment_method + ) + + @classmethod + def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): + logger.info(f"Stripe subscription updated: {stripe_sub.id}") + + assert stripe_sub.plan, f"Stripe subscription {stripe_sub.id} is missing plan" + assert ( + stripe_sub.plan.product + ), f"Stripe subscription {stripe_sub.id} is missing product" + + product = stripe.Product.retrieve(stripe_sub.plan.product) + plan = PricingPlan.get_by_stripe_product(product) + if not plan: + raise Exception( + f"PricingPlan not found for product {stripe_sub.plan.product}" + ) + + if stripe_sub.status.lower() != "active": + logger.info( + "Subscription is not active. Ignoring event", subscription=stripe_sub + ) + return + + _set_user_subscription( + provider=cls.PROVIDER, + plan=plan, + uid=uid, + external_id=stripe_sub.id, + ) + + @classmethod + def handle_subscription_cancelled(cls, uid: str, stripe_sub): + logger.info(f"Stripe subscription cancelled: {stripe_sub.id}") + _remove_subscription_for_user( + provider=cls.PROVIDER, uid=uid, external_id=stripe_sub.id + ) + + +def add_balance_for_payment( + *, + uid: str, + amount: int, + invoice_id: str, + payment_provider: PaymentProvider, + charged_amount: int, + **kwargs, +): + user = AppUser.objects.get_or_create_from_uid(uid)[0] + user.add_balance( + amount=amount, + invoice_id=invoice_id, + charged_amount=charged_amount, + payment_provider=payment_provider, + **kwargs, + ) + + if not user.is_paying: + user.is_paying = True + user.save(update_fields=["is_paying"]) + + if ( + user.subscription + and user.subscription.should_send_monthly_spending_notification() + ): + send_monthly_spending_notification_email.delay(user.id) + + +def _set_user_subscription( + *, provider: PaymentProvider, plan: PricingPlan, uid: str, external_id: str +): + with transaction.atomic(): + subscription, created = Subscription.objects.get_or_create( + payment_provider=provider, + external_id=external_id, + defaults=dict(plan=plan.db_value), + ) + subscription.plan = plan.db_value + subscription.full_clean() + subscription.save() + + user = AppUser.objects.get_or_create_from_uid(uid)[0] + existing = user.subscription + + user.subscription = subscription + user.save(update_fields=["subscription"]) + + if not existing: + return + + # cancel existing subscription if it's not the same as the new one + if existing.external_id != external_id: + existing.cancel() + + # delete old db record if it exists + if existing.id != subscription.id: + existing.delete() + + +def _remove_subscription_for_user( + *, uid: str, provider: PaymentProvider, external_id: str +): + AppUser.objects.filter( + uid=uid, + subscription__payment_provider=provider, + subscription__external_id=external_id, + ).update( + subscription=None, + ) diff --git a/poetry.lock b/poetry.lock index c63dd4849..9d7b61cb3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5122,17 +5122,18 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo [[package]] name = "sentry-sdk" -version = "1.34.0" +version = "1.45.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.34.0.tar.gz", hash = "sha256:e5d0d2b25931d88fa10986da59d941ac6037f742ab6ff2fce4143a27981d60c3"}, - {file = "sentry_sdk-1.34.0-py2.py3-none-any.whl", hash = "sha256:76dd087f38062ac6c1e30ed6feb533ee0037ff9e709974802db7b5dbf2e5db21"}, + {file = "sentry-sdk-1.45.0.tar.gz", hash = "sha256:509aa9678c0512344ca886281766c2e538682f8acfa50fd8d405f8c417ad0625"}, + {file = "sentry_sdk-1.45.0-py2.py3-none-any.whl", hash = "sha256:1ce29e30240cc289a027011103a8c83885b15ef2f316a60bcc7c5300afa144f1"}, ] [package.dependencies] certifi = "*" +loguru = {version = ">=0.5", optional = true, markers = "extra == \"loguru\""} urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""} [package.extras] @@ -5142,6 +5143,7 @@ asyncpg = ["asyncpg (>=0.23)"] beam = ["apache-beam (>=2.12)"] bottle = ["bottle (>=0.12.13)"] celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] chalice = ["chalice (>=1.16.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] django = ["django (>=1.8)"] @@ -5152,6 +5154,7 @@ grpcio = ["grpcio (>=1.21.1)"] httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] @@ -5445,17 +5448,18 @@ snowflake = ["snowflake-connector-python (>=2.8.0)", "snowflake-snowpark-python [[package]] name = "stripe" -version = "5.5.0" +version = "10.3.0" description = "Python bindings for the Stripe API" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.6" files = [ - {file = "stripe-5.5.0-py2.py3-none-any.whl", hash = "sha256:b4947da66dbb3de8969004ba6398f9a019c6b1b3ffe6aa88d5b07ac560a52b28"}, - {file = "stripe-5.5.0.tar.gz", hash = "sha256:04a9732b37a46228ecf0e496163a3edd93596b0e6200029fbc48911638627e19"}, + {file = "stripe-10.3.0-py2.py3-none-any.whl", hash = "sha256:95aa10d34e325cb6a19784412d6196621442c278b0c9cd3fe7be2a7ef180c2f8"}, + {file = "stripe-10.3.0.tar.gz", hash = "sha256:56515faf0cbee82f27d9b066403988a107301fc80767500be9789a25d65f2bae"}, ] [package.dependencies] requests = {version = ">=2.20", markers = "python_version >= \"3.0\""} +typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""} [[package]] name = "tabulate" @@ -6442,4 +6446,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "cc79f3b414323945ade371a12c4071eb50b9988715f0d094e4e9ef34008c3fe2" +content-hash = "2777d2a014b924fe8a7c2dfe63ebccbb55b5572abea44515898ca8e7fd7a17b0" diff --git a/pyproject.toml b/pyproject.toml index 8f81fb681..a470e162b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,10 @@ google-cloud-texttospeech = "^2.12.1" Wand = "^0.6.10" readability-lxml = "^0.8.1" transformers = "^4.24.0" -stripe = "^5.0.0" +stripe = "^10.3.0" python-multipart = "^0.0.5" html-sanitizer = "^1.9.3" plotly = "^5.11.0" -sentry-sdk = "^1.12.0" httpx = "^0.23.1" pyquery = "^1.4.3" redis = "^4.5.1" @@ -86,6 +85,7 @@ emoji = "^2.10.1" pyvespa = "^0.39.0" anthropic = "^0.25.5" azure-cognitiveservices-speech = "^1.37.0" +sentry-sdk = {version = "1.45.0", extras = ["loguru"]} [tool.poetry.group.dev.dependencies] watchdog = "^2.1.9" diff --git a/routers/account.py b/routers/account.py index 4df7afcf1..1d99d3d0a 100644 --- a/routers/account.py +++ b/routers/account.py @@ -2,29 +2,24 @@ from contextlib import contextmanager from enum import Enum -from django.db import transaction from fastapi import APIRouter from fastapi.requests import Request from furl import furl from loguru import logger +from requests.models import HTTPError import gooey_ui as st -from app_users.models import AppUser, PaymentProvider from bots.models import PublishedRun, PublishedRunVisibility, Workflow from daras_ai_v2 import icons, paypal from daras_ai_v2.base import RedirectException from daras_ai_v2.billing import billing_page -from daras_ai_v2.fastapi_tricks import ( - get_route_path, - get_route_url, -) +from daras_ai_v2.fastapi_tricks import get_route_path, get_route_url from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import raw_build_meta_tags from daras_ai_v2.profiles import edit_user_profile_page from gooey_ui.components.pills import pill -from payments.models import Subscription -from payments.plans import PricingPlan +from payments.webhooks import PaypalWebhookHandler from routers.root import page_wrapper, get_og_url_path app = APIRouter() @@ -33,20 +28,25 @@ @app.post("/payment-processing/") @st.route def payment_processing_route( - request: Request, provider: str = None, subscription_id: str = None + request: Request, provider: str | None = None, subscription_id: str | None = None ): - subtext = None waiting_time_sec = 3 + subtext = None if provider == "paypal": - if sub_id := subscription_id: - sub = paypal.Subscription.retrieve(sub_id) - paypal_handle_subscription_updated(sub) + success = st.run_in_thread( + threaded_paypal_handle_subscription_updated, + args=[subscription_id], + ) + if success: + # immediately redirect + waiting_time_sec = 0 else: + # either failed or still running. in either case, wait 30s before redirecting + waiting_time_sec = 30 subtext = ( "PayPal transactions take up to a minute to reflect in your account" ) - waiting_time_sec = 30 with page_wrapper(request, className="m-auto"): with st.center(): @@ -56,7 +56,9 @@ def payment_processing_route( style=dict(height="3rem", width="3rem"), ) st.write("# Processing payment...") - st.caption(subtext) + + if subtext: + st.caption(subtext) st.js( # language=JavaScript @@ -229,36 +231,14 @@ def account_page_wrapper(request: Request, current_tab: TabData): yield -@transaction.atomic -def paypal_handle_subscription_updated(subscription: paypal.Subscription): - logger.info("Subscription updated") - - plan = PricingPlan.get_by_paypal_plan_id(subscription.plan_id) - if not plan: - logger.error(f"Invalid plan ID: {subscription.plan_id}") - return - - if not subscription.status == "ACTIVE": - logger.warning(f"Subscription {subscription.id} is not active") - return - - user = AppUser.objects.get(uid=subscription.custom_id) - if user.subscription and ( - user.subscription.payment_provider != PaymentProvider.PAYPAL - or user.subscription.external_id != subscription.id - ): - logger.warning( - f"User {user} has different existing subscription {user.subscription}. Cancelling that..." - ) - user.subscription.cancel() - user.subscription.delete() - elif not user.subscription: - user.subscription = Subscription() - - user.subscription.plan = plan.db_value - user.subscription.payment_provider = PaymentProvider.PAYPAL - user.subscription.external_id = subscription.id - - user.subscription.full_clean() - user.subscription.save() - user.save(update_fields=["subscription"]) +def threaded_paypal_handle_subscription_updated(subscription_id: str) -> bool: + """ + Always returns True when completed (for use in st.run_in_thread()) + """ + try: + subscription = paypal.Subscription.retrieve(subscription_id) + PaypalWebhookHandler.handle_subscription_updated(subscription) + except HTTPError: + logger.exception(f"Unexpected PayPal error for sub: {subscription_id}") + return False + return True diff --git a/routers/api.py b/routers/api.py index 8127b1b24..30d834943 100644 --- a/routers/api.py +++ b/routers/api.py @@ -10,6 +10,7 @@ from fastapi import Form from fastapi import HTTPException from fastapi import Response +from fastapi.exceptions import RequestValidationError from furl import furl from pydantic import BaseModel, Field from pydantic import ValidationError @@ -18,16 +19,20 @@ from starlette.datastructures import FormData from starlette.datastructures import UploadFile from starlette.requests import Request +from starlette.status import ( + HTTP_402_PAYMENT_REQUIRED, + HTTP_429_TOO_MANY_REQUESTS, + HTTP_500_INTERNAL_SERVER_ERROR, + HTTP_400_BAD_REQUEST, +) import gooey_ui as st from app_users.models import AppUser from auth.token_authentication import api_auth_header from bots.models import RetentionPolicy -from celeryapp.tasks import auto_recharge from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings from daras_ai_v2.all_pages import all_api_pages -from daras_ai_v2.auto_recharge import user_should_auto_recharge from daras_ai_v2.base import ( BasePage, RecipeRunState, @@ -35,7 +40,6 @@ from daras_ai_v2.fastapi_tricks import fastapi_request_form from functions.models import CalledFunctionResponse from gooeysite.bg_db_conn import get_celery_result_db_safe -from routers.account import AccountTabs app = APIRouter() @@ -92,7 +96,7 @@ class AsyncApiResponseModelV3(BaseResponseModelV3): class AsyncStatusResponseModelV3(BaseResponseModelV3, typing.Generic[O]): - run_time_sec: int = Field(description="Total run time in seconds") + run_time_sec: float = Field(description="Total run time in seconds") status: RecipeRunState = Field(description="Status of the run") detail: str = Field( description="Details about the status of the run as a human readable string" @@ -129,14 +133,17 @@ def script_to_api(page_cls: typing.Type[BasePage]): ) common_errs = { - 402: {"model": GenericErrorResponse}, - 429: {"model": GenericErrorResponse}, + HTTP_402_PAYMENT_REQUIRED: {"model": GenericErrorResponse}, + HTTP_429_TOO_MANY_REQUESTS: {"model": GenericErrorResponse}, } @app.post( os.path.join(endpoint, ""), response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + **common_errs, + }, operation_id=page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v2 sync)", @@ -144,7 +151,10 @@ def script_to_api(page_cls: typing.Type[BasePage]): @app.post( endpoint, response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + **common_errs, + }, include_in_schema=False, ) def run_api_json( @@ -163,13 +173,21 @@ def run_api_json( @app.post( os.path.join(endpoint, "form/"), response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) @app.post( os.path.join(endpoint, "form"), response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) def run_api_form( @@ -223,16 +241,22 @@ def run_api_json_async( @app.post( os.path.join(endpoint, "async/form/"), response_model=response_model, - responses=common_errs, + responses={ + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) @app.post( os.path.join(endpoint, "async/form"), response_model=response_model, - responses=common_errs, + responses={ + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) - def run_api_form( + def run_api_form_async( request: Request, response: Response, user: AppUser = Depends(api_auth_header), @@ -278,9 +302,10 @@ def get_run_status( "created_at": sr.created_at.isoformat(), "run_time_sec": sr.run_time.total_seconds(), } - if sr.error_msg: + if sr.error_code: + raise HTTPException(sr.error_code, detail=ret | {"error": sr.error_msg}) + elif sr.error_msg: ret |= {"status": "failed", "detail": sr.error_msg} - return ret else: status = self.get_run_state(sr.to_dict()) ret |= {"detail": sr.run_status or "", "status": status} @@ -310,7 +335,10 @@ def _parse_form_data( try: is_str = request_model.schema()["properties"][key]["type"] == "string" except KeyError: - raise HTTPException(status_code=400, detail=f'Inavlid file field "{key}"') + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=dict(error=f'Inavlid file field "{key}"'), + ) if is_str: page_request_data[key] = urls[0] else: @@ -319,7 +347,7 @@ def _parse_form_data( try: page_request = request_model.parse_obj(page_request_data) except ValidationError as e: - raise HTTPException(status_code=422, detail=e.errors()) + raise RequestValidationError(e.errors(), body=page_request_data) return page_request @@ -373,24 +401,15 @@ def submit_api_call( st.set_session_state(state) st.set_query_params(query_params) - if user_should_auto_recharge(self.request.user): - auto_recharge.delay(user_id=self.request.user.id) - - # check the balance - if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits(): - account_url = furl(settings.APP_BASE_URL) / AccountTabs.billing.url_path - raise HTTPException( - status_code=402, - detail=dict( - error=f"Doh! You need to purchase additional credits to run more Gooey.AI recipes: {account_url}" - ), - ) # create a new run - sr = self.create_new_run( - enable_rate_limits=enable_rate_limits, - is_api_call=True, - retention_policy=retention_policy or RetentionPolicy.keep, - ) + try: + sr = self.create_new_run( + enable_rate_limits=enable_rate_limits, + is_api_call=True, + retention_policy=retention_policy or RetentionPolicy.keep, + ) + except ValidationError as e: + raise RequestValidationError(e.errors(), body=request_body) # submit the task result = self.call_runner_task(sr) return self, result, sr.run_id, sr.uid @@ -429,7 +448,7 @@ def build_api_response( # check for errors if sr.error_msg: raise HTTPException( - status_code=500, + status_code=sr.error_code or HTTP_500_INTERNAL_SERVER_ERROR, detail={ "id": run_id, "url": web_url, diff --git a/routers/paypal.py b/routers/paypal.py index 2d540122a..c59668fc1 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -12,17 +12,13 @@ from loguru import logger from pydantic import BaseModel -from app_users.models import AppUser, PaymentProvider +from app_users.models import PaymentProvider, TransactionReason from daras_ai_v2 import paypal, settings from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fastapi_tricks import fastapi_request_json, get_route_url from payments.models import PricingPlan -from payments.tasks import send_monthly_spending_notification_email -from routers.account import ( - paypal_handle_subscription_updated, - payment_processing_route, - account_route, -) +from payments.webhooks import PaypalWebhookHandler, add_balance_for_payment +from routers.account import payment_processing_route, account_route router = APIRouter() @@ -150,36 +146,6 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json): return JSONResponse(content=jsonable_encoder(pp_subscription), status_code=200) -@router.post("/__/paypal/webhook/") -def webhook(request: Request, payload: dict = fastapi_request_json): - if not paypal.verify_webhook_event(payload, headers=request.headers): - logger.error("Invalid PayPal webhook signature") - return JSONResponse({"error": "Invalid signature"}, status_code=400) - - try: - event = WebhookEvent.parse_obj(payload) - except pydantic.ValidationError as e: - logger.error(f"Invalid PayPal webhook payload: {json.dumps(e)}") - return JSONResponse({"error": "Invalid event type"}, status_code=400) - - logger.info(f"Received event: {event.event_type}") - - match event.event_type: - case "PAYMENT.SALE.COMPLETED": - event = SaleCompletedEvent.parse_obj(event) - _handle_sale_completed(event) - case "BILLING.SUBSCRIPTION.ACTIVATED" | "BILLING.SUBSCRIPTION.UPDATED": - event = SubscriptionEvent.parse_obj(event) - paypal_handle_subscription_updated(event.resource) - case "BILLING.SUBSCRIPTION.CANCELLED" | "BILLING.SUBSCRIPTION.EXPIRED": - event = SubscriptionEvent.parse_obj(event) - _handle_subscription_cancelled(event.resource) - case _: - logger.error(f"Unhandled PayPal webhook event: {event.event_type}") - - return JSONResponse({}, status_code=200) - - # Capture payment for the created order to complete the transaction. # @see https://developer.paypal.com/docs/api/orders/v2/#orders_capture @router.post("/__/paypal/orders/{order_id}/capture/") @@ -208,64 +174,39 @@ def _handle_invoice_paid(order_id: str): raise_for_status(response) order = response.json() purchase_unit = order["purchase_units"][0] - uid = purchase_unit["payments"]["captures"][0]["custom_id"] - user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance( - payment_provider=PaymentProvider.PAYPAL, - invoice_id=order_id, + payment_capture = purchase_unit["payments"]["captures"][0] + add_balance_for_payment( + uid=payment_capture["custom_id"], amount=int(purchase_unit["items"][0]["quantity"]), - charged_amount=int(float(purchase_unit["amount"]["value"]) * 100), - ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) - - -def _handle_sale_completed(event: SaleCompletedEvent): - sale = event.resource - if not sale.billing_agreement_id: - logger.warning(f"Sale {sale.id} is missing subscription ID") - return - - pp_subscription = paypal.Subscription.retrieve(sale.billing_agreement_id) - if not pp_subscription.custom_id: - logger.error(f"Subscription {pp_subscription.id} is missing custom ID") - return - - assert pp_subscription.plan_id, "Subscription is missing plan ID" - plan = PricingPlan.get_by_paypal_plan_id(pp_subscription.plan_id) - if not plan: - logger.error(f"Invalid plan ID: {pp_subscription.plan_id}") - return - - if float(sale.amount.total) == float(plan.monthly_charge): - new_credits = plan.credits - else: - new_credits = int(float(sale.amount.total) * settings.ADDON_CREDITS_PER_DOLLAR) - - user = AppUser.objects.get(uid=pp_subscription.custom_id) - user.add_balance( + invoice_id=payment_capture["id"], payment_provider=PaymentProvider.PAYPAL, - invoice_id=sale.id, - amount=new_credits, - charged_amount=int(float(sale.amount.total) * 100), + charged_amount=int(float(purchase_unit["amount"]["value"]) * 100), ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) - if ( - user.subscription - and user.subscription.should_send_monthly_spending_notification() - ): - send_monthly_spending_notification_email.delay(user.id) - - -def _handle_subscription_cancelled(subscription: paypal.Subscription): - user = AppUser.objects.get(uid=subscription.custom_id) - if ( - user.subscription - and user.subscription.payment_provider == PaymentProvider.PAYPAL - and user.subscription.external_id == subscription.id - ): - user.subscription = None - user.save() + + +@router.post("/__/paypal/webhook") +def webhook(request: Request, payload: dict = fastapi_request_json): + if not paypal.verify_webhook_event(payload, headers=request.headers): + logger.error("Invalid PayPal webhook signature") + return JSONResponse({"error": "Invalid signature"}, status_code=400) + + try: + event = WebhookEvent.parse_obj(payload) + except pydantic.ValidationError as e: + logger.error(f"Invalid PayPal webhook payload: {json.dumps(e)}") + return JSONResponse({"error": "Invalid event type"}, status_code=400) + + logger.info(f"Received event: {event.event_type}") + + match event.event_type: + case "PAYMENT.SALE.COMPLETED": + sale = SaleCompletedEvent.parse_obj(event).resource + PaypalWebhookHandler.handle_sale_completed(sale) + case "BILLING.SUBSCRIPTION.ACTIVATED" | "BILLING.SUBSCRIPTION.UPDATED": + subscription = SubscriptionEvent.parse_obj(event).resource + PaypalWebhookHandler.handle_subscription_updated(subscription) + case "BILLING.SUBSCRIPTION.CANCELLED" | "BILLING.SUBSCRIPTION.EXPIRED": + subscription = SubscriptionEvent.parse_obj(event).resource + PaypalWebhookHandler.handle_subscription_cancelled(subscription) + + return JSONResponse({}) diff --git a/routers/stripe.py b/routers/stripe.py index 4ee3ef12a..3d481d47b 100644 --- a/routers/stripe.py +++ b/routers/stripe.py @@ -1,35 +1,15 @@ -from urllib.parse import quote_plus - import stripe -from django.db import transaction from fastapi import APIRouter, Request -from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.responses import JSONResponse from loguru import logger -from app_users.models import AppUser from daras_ai_v2 import settings -from daras_ai_v2.fastapi_tricks import ( - fastapi_request_body, - get_route_url, -) -from payments.models import PaymentProvider, Subscription -from payments.plans import PricingPlan -from payments.tasks import send_monthly_spending_notification_email -from routers.account import account_route +from daras_ai_v2.fastapi_tricks import fastapi_request_body +from payments.webhooks import StripeWebhookHandler router = APIRouter() -@router.post("/__/stripe/create-portal-session") -def customer_portal(request: Request): - customer = request.user.get_or_create_stripe_customer() - portal_session = stripe.billing_portal.Session.create( - customer=customer, - return_url=get_route_url(account_route), - ) - return RedirectResponse(portal_session.url, status_code=303) - - @router.post("/__/stripe/webhook") def webhook_received(request: Request, payload: bytes = fastapi_request_body): # Retrieve the event by verifying the signature using the raw body and secret if webhook signing is configured. @@ -55,104 +35,17 @@ def webhook_received(request: Request, payload: bytes = fastapi_request_body): status_code=400, ) + logger.info(f"Received event: {event['type']}") + # Get the type of webhook event sent - used to check the status of PaymentIntents. match event["type"]: case "invoice.paid": - handle_invoice_paid(uid, data) + StripeWebhookHandler.handle_invoice_paid(uid, data) case "checkout.session.completed": - _handle_checkout_session_completed(uid, data) + StripeWebhookHandler.handle_checkout_session_completed(uid, data) case "customer.subscription.created" | "customer.subscription.updated": - _handle_subscription_updated(uid, data) + StripeWebhookHandler.handle_subscription_updated(uid, data) case "customer.subscription.deleted": - _handle_subscription_cancelled(uid, data) + StripeWebhookHandler.handle_subscription_cancelled(uid, data) return JSONResponse({"status": "success"}) - - -def handle_invoice_paid(uid: str, invoice_data): - invoice_id = invoice_data.id - line_items = stripe.Invoice._static_request( - "get", - "/v1/invoices/{invoice}/lines".format(invoice=quote_plus(invoice_id)), - ) - user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance( - payment_provider=PaymentProvider.STRIPE, - invoice_id=invoice_id, - amount=line_items.data[0].quantity, - charged_amount=line_items.data[0].amount, - ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) - if ( - user.subscription - and user.subscription.should_send_monthly_spending_notification() - ): - send_monthly_spending_notification_email.delay(user.id) - - -def _handle_checkout_session_completed(uid: str, session_data): - setup_intent_id = session_data.get("setup_intent") - if not setup_intent_id: - # not a setup mode checkout - return - - # set default payment method - user = AppUser.objects.get_or_create_from_uid(uid)[0] - setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) - subscription_id = setup_intent.metadata.get("subscription_id") - if not ( - user.subscription.payment_provider == PaymentProvider.STRIPE - and user.subscription.external_id == subscription_id - ): - logger.error(f"Subscription {subscription_id} not found for user {user}") - return - - stripe.Subscription.modify( - subscription_id, default_payment_method=setup_intent.payment_method - ) - - -@transaction.atomic -def _handle_subscription_updated(uid: str, subscription_data): - logger.info("Subscription updated") - product = stripe.Product.retrieve(subscription_data.plan.product) - plan = PricingPlan.get_by_stripe_product(product) - if not plan: - raise Exception( - f"PricingPlan not found for product {subscription_data.plan.product}" - ) - - if subscription_data.get("status") != "active": - logger.warning(f"Subscription {subscription_data.id} is not active") - return - - user = AppUser.objects.get_or_create_from_uid(uid)[0] - if user.subscription and ( - user.subscription.payment_provider != PaymentProvider.STRIPE - or user.subscription.external_id != subscription_data.id - ): - logger.warning( - f"User {user} has different existing subscription {user.subscription}. Cancelling that..." - ) - user.subscription.cancel() - user.subscription.delete() - elif not user.subscription: - user.subscription = Subscription() - - user.subscription.plan = plan.db_value - user.subscription.payment_provider = PaymentProvider.STRIPE - user.subscription.external_id = subscription_data.id - - user.subscription.full_clean() - user.subscription.save() - user.save(update_fields=["subscription"]) - - -def _handle_subscription_cancelled(uid: str, subscription_data): - subscription = Subscription.objects.get_by_stripe_subscription_id( - subscription_data.id - ) - logger.info(f"Subscription {subscription} cancelled. Deleting it...") - subscription.delete() diff --git a/server.py b/server.py index 2d8cc5630..c20b82846 100644 --- a/server.py +++ b/server.py @@ -1,10 +1,14 @@ from fastapi.exception_handlers import ( - request_validation_exception_handler, http_exception_handler, + request_validation_exception_handler, ) from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException from starlette.requests import Request +from starlette.status import ( + HTTP_404_NOT_FOUND, + HTTP_405_METHOD_NOT_ALLOWED, +) from daras_ai_v2.pydantic_validation import convert_errors from daras_ai_v2.settings import templates @@ -113,8 +117,8 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE return await request_validation_exception_handler(request, exc) -@app.exception_handler(404) -@app.exception_handler(405) +@app.exception_handler(HTTP_404_NOT_FOUND) +@app.exception_handler(HTTP_405_METHOD_NOT_ALLOWED) async def not_found_exception_handler(request: Request, exc: HTTPException): if not request.headers.get("accept", "").startswith("text/html"): return await http_exception_handler(request, exc) diff --git a/templates/run_complete_email.html b/templates/run_complete_email.html index 613ce3903..c7ae03bb0 100644 --- a/templates/run_complete_email.html +++ b/templates/run_complete_email.html @@ -1,20 +1,14 @@ -

- Your {{ title }} Gooey.AI run completed in {{ run_time_sec }} seconds. -

-

- View output here: {{ app_url }} -

-

- Your prompt: {{ prompt }} -

+

Your {{ recipe_title }} Gooey.AI run completed in {{ run_time_sec }} seconds.

+

View output here: {{ app_url }}

-

- We can’t wait to see what you build with Gooey! -

+{% if prompt %} +

Your prompt: {{ prompt }}

+{% endif %} + +

We can’t wait to see what you build with Gooey!

- Cheers, -
+ Cheers,
The Gooey.AI Team

-{{ "{{{ pm:unsubscribe }}}" }} +{{ "{{{ pm:unsubscribe }}}" }} diff --git a/tests/test_apis.py b/tests/test_apis.py index bd3a915fb..fa897eb83 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -15,7 +15,7 @@ @pytest.mark.django_db -def test_apis_sync(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_sync(mock_celery_tasks, force_authentication, threadpool_subtest): for page_cls in all_test_pages: threadpool_subtest(_test_api_sync, page_cls) @@ -32,7 +32,7 @@ def _test_api_sync(page_cls: typing.Type[BasePage]): @pytest.mark.django_db -def test_apis_async(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_async(mock_celery_tasks, force_authentication, threadpool_subtest): for page_cls in all_test_pages: threadpool_subtest(_test_api_async, page_cls) @@ -65,7 +65,7 @@ def _test_api_async(page_cls: typing.Type[BasePage]): @pytest.mark.django_db -def test_apis_examples(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_examples(mock_celery_tasks, force_authentication, threadpool_subtest): qs = ( PublishedRun.objects.exclude(is_approved_example=False) .exclude(published_run_id="") diff --git a/tests/test_checkout.py b/tests/test_checkout.py index 4e412543a..19a988543 100644 --- a/tests/test_checkout.py +++ b/tests/test_checkout.py @@ -3,7 +3,7 @@ from app_users.models import AppUser from daras_ai_v2 import settings -from daras_ai_v2.billing import create_stripe_checkout_session +from daras_ai_v2.billing import stripe_subscription_checkout_redirect from gooey_ui import RedirectException from payments.plans import PricingPlan from server import app @@ -20,4 +20,4 @@ def test_create_checkout_session( return with pytest.raises(RedirectException): - create_stripe_checkout_session(force_authentication, plan) + stripe_subscription_checkout_redirect(force_authentication, plan) diff --git a/tests/test_integrations_api.py b/tests/test_integrations_api.py index c6f11c9d8..398fdd52a 100644 --- a/tests/test_integrations_api.py +++ b/tests/test_integrations_api.py @@ -11,7 +11,7 @@ @pytest.mark.django_db -def test_send_msg_streaming(mock_gui_runner, force_authentication): +def test_send_msg_streaming(mock_celery_tasks, force_authentication): r = client.post( "/v3/integrations/stream/", json={