diff --git a/app_users/admin.py b/app_users/admin.py
index da50b57e9..caa61d223 100644
--- a/app_users/admin.py
+++ b/app_users/admin.py
@@ -192,19 +192,24 @@ class AppUserTransactionAdmin(admin.ModelAdmin):
"invoice_id",
"user",
"amount",
+ "dollar_amount",
"end_balance",
"payment_provider",
- "dollar_amount",
+ "reason",
+ "plan",
"created_at",
]
- readonly_fields = ["created_at"]
+ readonly_fields = ["view_payment_provider_url", "created_at"]
list_filter = [
- "created_at",
+ "reason",
("payment_provider", admin.EmptyFieldListFilter),
"payment_provider",
+ "plan",
+ "created_at",
]
inlines = [SavedRunInline]
ordering = ["-created_at"]
+ search_fields = ["invoice_id"]
@admin.display(description="Charged Amount")
def dollar_amount(self, obj: models.AppUserTransaction):
@@ -212,6 +217,14 @@ def dollar_amount(self, obj: models.AppUserTransaction):
return
return f"${obj.charged_amount / 100}"
+ @admin.display(description="Payment Provider URL")
+ def view_payment_provider_url(self, txn: models.AppUserTransaction):
+ url = txn.payment_provider_url()
+ if url:
+ return open_in_new_tab(url, label=url)
+ else:
+ raise txn.DoesNotExist
+
@admin.register(LogEntry)
class LogEntryAdmin(admin.ModelAdmin):
diff --git a/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py b/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py
new file mode 100644
index 000000000..2891795fd
--- /dev/null
+++ b/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py
@@ -0,0 +1,59 @@
+# Generated by Django 4.2.7 on 2024-07-14 20:51
+
+from django.db import migrations, models
+
+
+def forwards_func(apps, schema_editor):
+ from payments.plans import PricingPlan
+ from app_users.models import TransactionReason
+
+ # We get the model from the versioned app registry;
+ # if we directly import it, it'll be the wrong version
+ AppUserTransaction = apps.get_model("app_users", "AppUserTransaction")
+ db_alias = schema_editor.connection.alias
+ objects = AppUserTransaction.objects.using(db_alias)
+
+ for transaction in objects.all():
+ if transaction.amount <= 0:
+ transaction.reason = TransactionReason.DEDUCT
+ else:
+ # For old transactions, we didn't have a subscription field.
+ # It just so happened that all monthly subscriptions we offered had
+ # different amounts from the one-time purchases.
+ # This uses that heuristic to determine whether a transaction
+ # was a subscription payment or a one-time purchase.
+ transaction.reason = TransactionReason.ADDON
+ for plan in PricingPlan:
+ if (
+ transaction.amount == plan.credits
+ and transaction.charged_amount == plan.monthly_charge * 100
+ ):
+ transaction.plan = plan.db_value
+ transaction.reason = TransactionReason.SUBSCRIBE
+ transaction.save(update_fields=["reason", "plan"])
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('app_users', '0017_alter_appuser_subscription'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='appusertransaction',
+ name='plan',
+ field=models.IntegerField(blank=True, choices=[(1, 'Basic Plan'), (2, 'Premium Plan'), (3, 'Starter'), (4, 'Creator'), (5, 'Business'), (6, 'Enterprise / Agency')], default=None, help_text="User's plan at the time of this transaction.", null=True),
+ ),
+ migrations.AddField(
+ model_name='appusertransaction',
+ name='reason',
+ field=models.IntegerField(choices=[(1, 'Deduct'), (2, 'Addon'), (3, 'Subscribe'), (4, 'Sub-Create'), (5, 'Sub-Cycle'), (6, 'Sub-Update'), (7, 'Auto-Recharge')], default=0, help_text='The reason for this transaction.
Deduct: Credits deducted due to a run.
Addon: User purchased an add-on.
Subscribe: Applies to subscriptions where no distinction was made between create, update and cycle.
Sub-Create: A subscription was created.
Sub-Cycle: A subscription advanced into a new period.
Sub-Update: A subscription was updated.
Auto-Recharge: Credits auto-recharged due to low balance.'),
+ ),
+ migrations.RunPython(forwards_func, migrations.RunPython.noop),
+ migrations.AlterField(
+ model_name='appusertransaction',
+ name='reason',
+ field=models.IntegerField(choices=[(1, 'Deduct'), (2, 'Addon'), (3, 'Subscribe'), (4, 'Sub-Create'), (5, 'Sub-Cycle'), (6, 'Sub-Update'), (7, 'Auto-Recharge')], help_text='The reason for this transaction.
Deduct: Credits deducted due to a run.
Addon: User purchased an add-on.
Subscribe: Applies to subscriptions where no distinction was made between create, update and cycle.
Sub-Create: A subscription was created.
Sub-Cycle: A subscription advanced into a new period.
Sub-Update: A subscription was updated.
Auto-Recharge: Credits auto-recharged due to low balance.'),
+ ),
+ ]
diff --git a/app_users/models.py b/app_users/models.py
index f4b29f490..1e1016520 100644
--- a/app_users/models.py
+++ b/app_users/models.py
@@ -4,6 +4,7 @@
from django.db.models import Sum
from django.utils import timezone
from firebase_admin import auth
+from furl import furl
from phonenumber_field.modelfields import PhoneNumberField
from bots.custom_fields import CustomURLField, StrippedTextField
@@ -172,6 +173,7 @@ def add_balance(
user: AppUser = AppUser.objects.select_for_update().get(pk=self.pk)
user.balance += amount
user.save(update_fields=["balance"])
+ kwargs.setdefault("plan", user.subscription and user.subscription.plan)
return AppUserTransaction.objects.create(
user=self,
invoice_id=invoice_id,
@@ -273,6 +275,18 @@ def get_dollars_spent_this_month(self) -> float:
return (cents_spent or 0) / 100
+class TransactionReason(models.IntegerChoices):
+ DEDUCT = 1, "Deduct"
+ ADDON = 2, "Addon"
+
+ SUBSCRIBE = 3, "Subscribe"
+ SUBSCRIPTION_CREATE = 4, "Sub-Create"
+ SUBSCRIPTION_CYCLE = 5, "Sub-Cycle"
+ SUBSCRIPTION_UPDATE = 6, "Sub-Update"
+
+ AUTO_RECHARGE = 7, "Auto-Recharge"
+
+
class AppUserTransaction(models.Model):
user = models.ForeignKey(
"AppUser", on_delete=models.CASCADE, related_name="transactions"
@@ -307,6 +321,25 @@ class AppUserTransaction(models.Model):
default=0,
)
+ reason = models.IntegerField(
+ choices=TransactionReason.choices,
+ help_text="The reason for this transaction.
"
+ f"{TransactionReason.DEDUCT.label}: Credits deducted due to a run.
"
+ f"{TransactionReason.ADDON.label}: User purchased an add-on.
"
+ f"{TransactionReason.SUBSCRIBE.label}: Applies to subscriptions where no distinction was made between create, update and cycle.
"
+ f"{TransactionReason.SUBSCRIPTION_CREATE.label}: A subscription was created.
"
+ f"{TransactionReason.SUBSCRIPTION_CYCLE.label}: A subscription advanced into a new period.
"
+ f"{TransactionReason.SUBSCRIPTION_UPDATE.label}: A subscription was updated.
"
+ f"{TransactionReason.AUTO_RECHARGE.label}: Credits auto-recharged due to low balance.",
+ )
+ plan = models.IntegerField(
+ choices=PricingPlan.db_choices(),
+ help_text="User's plan at the time of this transaction.",
+ null=True,
+ blank=True,
+ default=None,
+ )
+
created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now)
class Meta:
@@ -320,32 +353,41 @@ class Meta:
def __str__(self):
return f"{self.invoice_id} ({self.amount})"
- def get_subscription_plan(self) -> PricingPlan | None:
- """
- It just so happened that all monthly subscriptions we offered had
- different amounts from the one-time purchases.
- This uses that heuristic to determine whether a transaction
- was a subscription payment or a one-time purchase.
-
- TODO: Implement this more robustly
- """
- if self.amount <= 0:
- # credits deducted
- return None
-
- for plan in PricingPlan:
- if (
- self.amount == plan.credits
- and self.charged_amount == plan.monthly_charge * 100
+ def save(self, *args, **kwargs):
+ if self.reason is None:
+ if self.amount <= 0:
+ self.reason = TransactionReason.DEDUCT
+ else:
+ self.reason = TransactionReason.ADDON
+ super().save(*args, **kwargs)
+
+ def reason_note(self) -> str:
+ match self.reason:
+ case (
+ TransactionReason.SUBSCRIPTION_CREATE
+ | TransactionReason.SUBSCRIPTION_CYCLE
+ | TransactionReason.SUBSCRIPTION_UPDATE
+ | TransactionReason.SUBSCRIBE
):
- return plan
-
- return None
-
- def note(self) -> str:
- if self.amount <= 0:
- return ""
- elif plan := self.get_subscription_plan():
- return f"Subscription payment: {plan.title} (+{self.amount:,} credits)"
- else:
- return f"Addon purchase (+{self.amount:,} credits)"
+ ret = "Subscription payment"
+ if self.plan:
+ ret += f": {PricingPlan.from_db_value(self.plan).title}"
+ return ret
+ case TransactionReason.AUTO_RECHARGE:
+ return "Auto recharge"
+ case TransactionReason.ADDON:
+ return "Addon purchase"
+ case TransactionReason.DEDUCT:
+ return "Run deduction"
+
+ def payment_provider_url(self) -> str | None:
+ match self.payment_provider:
+ case PaymentProvider.STRIPE:
+ return str(
+ furl("https://dashboard.stripe.com/invoices/") / self.invoice_id
+ )
+ case PaymentProvider.PAYPAL:
+ return str(
+ furl("https://www.paypal.com/unifiedtransactions/details/payment/")
+ / self.invoice_id
+ )
diff --git a/bots/admin.py b/bots/admin.py
index 82da0aab2..e154f03b2 100644
--- a/bots/admin.py
+++ b/bots/admin.py
@@ -31,7 +31,7 @@
Workflow,
)
from bots.tasks import create_personal_channels_for_all_members
-from celeryapp.tasks import gui_runner
+from celeryapp.tasks import runner_task
from daras_ai_v2.fastapi_tricks import get_route_url
from gooeysite.custom_actions import export_to_excel, export_to_csv
from gooeysite.custom_filters import (
diff --git a/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py
new file mode 100644
index 000000000..44033b955
--- /dev/null
+++ b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py
@@ -0,0 +1,28 @@
+# Generated by Django 4.2.7 on 2024-07-12 19:30
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('bots', '0076_alter_workflowmetadata_default_image_and_more'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='savedrun',
+ name='error_code',
+ field=models.IntegerField(blank=True, default=None, help_text='The HTTP status code of the error. If this is not set, 500 is assumed.', null=True),
+ ),
+ migrations.AddField(
+ model_name='savedrun',
+ name='error_type',
+ field=models.TextField(blank=True, default='', help_text='The exception type'),
+ ),
+ migrations.AlterField(
+ model_name='savedrun',
+ name='error_msg',
+ field=models.TextField(blank=True, default='', help_text='The error message. If this is not set, the run is deemed successful.'),
+ ),
+ ]
diff --git a/bots/models.py b/bots/models.py
index 6ff6674ff..757766a4c 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -212,10 +212,24 @@ class SavedRun(models.Model):
state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder)
- error_msg = models.TextField(default="", blank=True)
+ error_msg = models.TextField(
+ default="",
+ blank=True,
+ help_text="The error message. If this is not set, the run is deemed successful.",
+ )
run_time = models.DurationField(default=datetime.timedelta, blank=True)
run_status = models.TextField(default="", blank=True)
+ error_code = models.IntegerField(
+ null=True,
+ default=None,
+ blank=True,
+ help_text="The HTTP status code of the error. If this is not set, 500 is assumed.",
+ )
+ error_type = models.TextField(
+ default="", blank=True, help_text="The exception type"
+ )
+
hidden = models.BooleanField(default=False)
is_flagged = models.BooleanField(default=False)
@@ -282,9 +296,12 @@ def __str__(self):
def parent_published_run(self) -> typing.Optional["PublishedRun"]:
return self.parent_version and self.parent_version.published_run
- def get_app_url(self):
+ def get_app_url(self, query_params: dict = None):
return Workflow(self.workflow).page_cls.app_url(
- example_id=self.example_id, run_id=self.run_id, uid=self.uid
+ example_id=self.example_id,
+ run_id=self.run_id,
+ uid=self.uid,
+ query_params=query_params,
)
def to_dict(self) -> dict:
@@ -1624,9 +1641,9 @@ def duplicate(
visibility=visibility,
)
- def get_app_url(self):
+ def get_app_url(self, query_params: dict = None):
return Workflow(self.workflow).page_cls.app_url(
- example_id=self.published_run_id
+ example_id=self.published_run_id, query_params=query_params
)
def add_version(
diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py
index b2a7b4327..2e30e4379 100644
--- a/celeryapp/tasks.py
+++ b/celeryapp/tasks.py
@@ -14,32 +14,34 @@
import gooey_ui as st
from app_users.models import AppUser, AppUserTransaction
from bots.admin_links import change_obj_url
-from bots.models import SavedRun, Platform
+from bots.models import SavedRun, Platform, Workflow
from celeryapp.celeryconfig import app
from daras_ai.image_input import truncate_text_words
from daras_ai_v2 import settings
-from daras_ai_v2.auto_recharge import auto_recharge_user
from daras_ai_v2.base import StateKeys, BasePage
from daras_ai_v2.exceptions import UserError
-from daras_ai_v2.redis_cache import redis_lock
from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
from gooey_ui.state import set_query_params
from gooeysite.bg_db_conn import db_middleware
+from payments.auto_recharge import (
+ should_attempt_auto_recharge,
+ run_auto_recharge_gracefully,
+)
DEFAULT_RUN_STATUS = "Running..."
@app.task
-def gui_runner(
+def runner_task(
*,
page_cls: typing.Type[BasePage],
user_id: int,
run_id: str,
uid: str,
channel: str,
-):
+) -> int:
start_time = time()
error_msg = None
@@ -89,34 +91,50 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
save_on_step()
for val in page.main(sr, st.session_state):
save_on_step(val)
+
# render errors nicely
except Exception as e:
- if isinstance(e, HTTPException) and e.status_code == 402:
- error_msg = page.generate_credit_error_message(run_id, uid)
- try:
- raise UserError(error_msg) from e
- except UserError as e:
- sentry_sdk.capture_exception(e, level=e.sentry_level)
+ if isinstance(e, UserError):
+ sentry_level = e.sentry_level
else:
- if isinstance(e, UserError):
- sentry_level = e.sentry_level
- else:
- sentry_level = "error"
- traceback.print_exc()
- sentry_sdk.capture_exception(e, level=sentry_level)
- error_msg = err_msg_for_exc(e)
+ sentry_level = "error"
+ traceback.print_exc()
+ sentry_sdk.capture_exception(e, level=sentry_level)
+ error_msg = err_msg_for_exc(e)
+ sr.error_type = type(e).__qualname__
+ sr.error_code = getattr(e, "status_code", None)
+
# run completed successfully, deduct credits
else:
sr.transaction, sr.price = page.deduct_credits(st.session_state)
+
+ # save everything, mark run as completed
finally:
save_on_step(done=True)
- if not sr.is_api_call:
- send_email_on_completion(page, sr)
- run_low_balance_email_check(user)
+
+ return sr.id
+
+
+@app.task
+def post_runner_tasks(saved_run_id: int):
+ sr = SavedRun.objects.get(id=saved_run_id)
+ user = AppUser.objects.get(uid=sr.uid)
+
+ if not sr.is_api_call:
+ send_email_on_completion(sr)
+
+ if should_attempt_auto_recharge(user):
+ run_auto_recharge_gracefully(user)
+
+ run_low_balance_email_check(user)
def err_msg_for_exc(e: Exception):
- if isinstance(e, requests.HTTPError):
+ if isinstance(e, UserError):
+ return e.message
+ elif isinstance(e, HTTPException):
+ return f"(HTTP {e.status_code}) {e.detail})"
+ elif isinstance(e, requests.HTTPError):
response: requests.Response = e.response
try:
err_body = response.json()
@@ -133,10 +151,6 @@ def err_msg_for_exc(e: Exception):
return f"(GPU) {err_type}: {err_str}"
err_str = str(err_body)
return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}"
- elif isinstance(e, HTTPException):
- return f"(HTTP {e.status_code}) {e.detail})"
- elif isinstance(e, UserError):
- return e.message
else:
return f"{type(e).__name__}: {e}"
@@ -177,7 +191,7 @@ def run_low_balance_email_check(user: AppUser):
user.save(update_fields=["low_balance_email_sent_at"])
-def send_email_on_completion(page: BasePage, sr: SavedRun):
+def send_email_on_completion(sr: SavedRun):
run_time_sec = sr.run_time.total_seconds()
if (
run_time_sec <= settings.SEND_RUN_EMAIL_AFTER_SEC
@@ -189,9 +203,16 @@ def send_email_on_completion(page: BasePage, sr: SavedRun):
)
if not to_address:
return
- prompt = (page.preview_input(sr.state) or "").strip()
- title = (sr.state.get("__title") or page.title).strip()
- subject = f"🌻 “{truncate_text_words(prompt, maxlen=50)}” {title} is done"
+
+ workflow = Workflow(sr.workflow)
+ page_cls = workflow.page_cls
+ prompt = (page_cls.preview_input(sr.state) or "").strip().replace("\n", " ")
+ recipe_title = page_cls.get_recipe_title()
+
+ subject = (
+ f"🌻 “{truncate_text_words(prompt, maxlen=50) or 'Run'}” {recipe_title} is done"
+ )
+
send_email_via_postmark(
from_address=settings.SUPPORT_EMAIL,
to_address=to_address,
@@ -200,7 +221,7 @@ def send_email_on_completion(page: BasePage, sr: SavedRun):
run_time_sec=round(run_time_sec),
app_url=sr.get_app_url(),
prompt=prompt,
- title=title,
+ recipe_title=recipe_title,
),
message_stream="gooey-ai-workflows",
)
@@ -221,11 +242,3 @@ def send_integration_attempt_email(*, user_id: int, platform: Platform, run_url:
subject=f"{user.display_name} Attempted to Connect to {platform.label}",
html_body=html_body,
)
-
-
-@app.task
-def auto_recharge(*, user_id: int):
- redis_lock_key = f"gooey/auto_recharge/{user_id}"
- with redis_lock(redis_lock_key):
- user = AppUser.objects.get(id=user_id)
- auto_recharge_user(user)
diff --git a/conftest.py b/conftest.py
index 96fb837f1..a38c6a11a 100644
--- a/conftest.py
+++ b/conftest.py
@@ -51,16 +51,17 @@ def force_authentication():
@pytest.fixture
-def mock_gui_runner():
+def mock_celery_tasks():
with (
- patch("celeryapp.tasks.gui_runner", _mock_gui_runner),
+ patch("celeryapp.tasks.runner_task", _mock_runner_task),
+ patch("celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks),
patch("daras_ai_v2.bots.realtime_subscribe", _mock_realtime_subscribe),
):
yield
@app.task
-def _mock_gui_runner(
+def _mock_runner_task(
*, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
):
sr = page_cls.run_doc_sr(run_id, uid)
@@ -70,6 +71,11 @@ def _mock_gui_runner(
_mock_realtime_push(channel, sr.to_dict())
+@app.task
+def _mock_post_runner_tasks(*args, **kwargs):
+ pass
+
+
def _mock_realtime_push(channel, value):
redis_qs[channel].put(value)
diff --git a/daras_ai_v2/auto_recharge.py b/daras_ai_v2/auto_recharge.py
deleted file mode 100644
index 15c47e42f..000000000
--- a/daras_ai_v2/auto_recharge.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from loguru import logger
-
-from app_users.models import AppUser, PaymentProvider
-from payments.tasks import send_email_budget_reached, send_email_auto_recharge_failed
-
-
-def auto_recharge_user(user: AppUser):
- if not user_should_auto_recharge(user):
- logger.info(f"User doesn't need to auto-recharge: {user=}")
- return
-
- dollars_spent = user.get_dollars_spent_this_month()
- if (
- dollars_spent + user.subscription.auto_recharge_topup_amount
- > user.subscription.monthly_spending_budget
- ):
- if not user.subscription.has_sent_monthly_budget_email_this_month():
- send_email_budget_reached.delay(user.id)
- logger.info(f"User has reached the monthly budget: {user=}, {dollars_spent=}")
- return
-
- match user.subscription.payment_provider:
- case PaymentProvider.STRIPE:
- customer = user.search_stripe_customer()
- if not customer:
- logger.error(f"User doesn't have a stripe customer: {user=}")
- return
-
- try:
- invoice = user.subscription.stripe_get_or_create_auto_invoice(
- amount_in_dollars=user.subscription.auto_recharge_topup_amount,
- metadata_key="auto_recharge",
- )
-
- if invoice.status == "open":
- pm = user.subscription.stripe_get_default_payment_method()
- invoice.pay(payment_method=pm)
- logger.info(
- f"Payment attempted for auto recharge invoice: {user=}, {invoice=}"
- )
- elif invoice.status == "paid":
- logger.info(
- f"Auto recharge invoice already paid recently: {user=}, {invoice=}"
- )
- except Exception as e:
- logger.error(
- f"Error while auto-recharging user: {user=}, {e=}, {invoice=}"
- )
- send_email_auto_recharge_failed.delay(user.id)
-
- case PaymentProvider.PAYPAL:
- logger.error(f"Auto-recharge not supported for PayPal: {user=}")
- return
-
-
-def user_should_auto_recharge(user: AppUser):
- """
- whether an auto recharge should be attempted for the user
- """
- return (
- user.subscription
- and user.subscription.auto_recharge_enabled
- and user.balance < user.subscription.auto_recharge_balance_threshold
- )
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index cfb731d4c..a447b4252 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -38,7 +38,6 @@
from daras_ai.text_format import format_number_with_suffix
from daras_ai_v2 import settings, urls
from daras_ai_v2.api_examples_widget import api_example_generator
-from daras_ai_v2.auto_recharge import user_should_auto_recharge
from daras_ai_v2.breadcrumbs import render_breadcrumbs, get_title_breadcrumbs
from daras_ai_v2.copy_to_clipboard_button_widget import (
copy_to_clipboard_button,
@@ -49,6 +48,7 @@
from daras_ai_v2.db import (
ANONYMOUS_USER_COOKIE,
)
+from daras_ai_v2.exceptions import InsufficientCredits
from daras_ai_v2.fastapi_tricks import get_route_path
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.html_spinner_widget import html_spinner
@@ -83,6 +83,10 @@
from gooey_ui.components.modal import Modal
from gooey_ui.components.pills import pill
from gooey_ui.pubsub import realtime_pull
+from payments.auto_recharge import (
+ should_attempt_auto_recharge,
+ run_auto_recharge_gracefully,
+)
from routers.account import AccountTabs
from routers.root import RecipeTabs
@@ -141,7 +145,6 @@ class BasePage:
class RequestModel(BaseModel):
functions: list[RecipeFunction] | None = Field(
- None,
title="🧩 Functions",
)
variables: dict[str, typing.Any] = Field(
@@ -1411,6 +1414,8 @@ def render_usage_guide(self):
raise NotImplementedError
def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
+ yield from self.ensure_credits_and_auto_recharge(sr, state)
+
yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
@@ -1437,15 +1442,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
response = self.ResponseModel.construct()
# run the recipe
- gen = self.run_v2(request, response)
- while True:
- try:
- val = next(gen)
- except StopIteration:
- break
- finally:
+ try:
+ for val in self.run_v2(request, response):
state.update(response.dict(exclude_unset=True))
- yield val
+ yield val
+ finally:
+ state.update(response.dict(exclude_unset=True))
# validate the response if successful
self.ResponseModel.validate(response)
@@ -1595,8 +1597,6 @@ def estimate_run_duration(self) -> int | None:
pass
def on_submit(self):
- from celeryapp.tasks import auto_recharge
-
try:
sr = self.create_new_run(enable_rate_limits=True)
except ValidationError as e:
@@ -1608,15 +1608,7 @@ def on_submit(self):
st.session_state[StateKeys.error_msg] = e.detail.get("error", "")
return
- if user_should_auto_recharge(self.request.user):
- auto_recharge.delay(user_id=self.request.user.id)
-
- if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits():
- sr.run_status = ""
- sr.error_msg = self.generate_credit_error_message(sr.run_id, sr.uid)
- sr.save(update_fields=["run_status", "error_msg"])
- else:
- self.call_runner_task(sr)
+ self.call_runner_task(sr)
raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
@@ -1689,15 +1681,19 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun):
)
def call_runner_task(self, sr: SavedRun):
- from celeryapp.tasks import gui_runner
-
- return gui_runner.delay(
- page_cls=self.__class__,
- user_id=self.request.user.id,
- run_id=sr.run_id,
- uid=sr.uid,
- channel=self.realtime_channel_name(sr.run_id, sr.uid),
+ from celeryapp.tasks import runner_task, post_runner_tasks
+
+ chain = (
+ runner_task.s(
+ page_cls=self.__class__,
+ user_id=self.request.user.id,
+ run_id=sr.run_id,
+ uid=sr.uid,
+ channel=self.realtime_channel_name(sr.run_id, sr.uid),
+ )
+ | post_runner_tasks.s()
)
+ return chain.apply_async()
@classmethod
def realtime_channel_name(cls, run_id, uid):
@@ -2073,10 +2069,27 @@ def run_as_api_tab(self):
manage_api_keys(self.request.user)
- def check_credits(self) -> bool:
+ def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict):
+ if not settings.CREDITS_TO_DEDUCT_PER_RUN:
+ return
assert self.request, "request must be set to check credits"
assert self.request.user, "request.user must be set to check credits"
- return self.request.user.balance >= self.get_price_roundoff(st.session_state)
+
+ user = self.request.user
+ price = self.get_price_roundoff(state)
+
+ if user.balance >= price:
+ return
+
+ if should_attempt_auto_recharge(user):
+ yield "Low balance detected. Recharging..."
+ run_auto_recharge_gracefully(user)
+ user.refresh_from_db()
+
+ if user.balance >= price:
+ return
+
+ raise InsufficientCredits(self.request.user, sr)
def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]:
assert (
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index 1e9b2f764..709fde926 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -1,3 +1,5 @@
+from typing import Literal
+
import stripe
from django.core.exceptions import ValidationError
@@ -13,10 +15,14 @@
from gooey_ui.components.pills import pill
from payments.models import PaymentMethodSummary
from payments.plans import PricingPlan
+from scripts.migrate_existing_subscriptions import available_subscriptions
rounded_border = "w-100 border shadow-sm rounded py-4 px-3"
+PlanActionLabel = Literal["Upgrade", "Downgrade", "Contact Us", "Your Plan"]
+
+
def billing_page(user: AppUser):
render_payments_setup()
@@ -27,14 +33,15 @@ def billing_page(user: AppUser):
render_credit_balance(user)
with st.div(className="my-5"):
- render_all_plans(user)
+ selected_payment_provider = render_all_plans(user)
+
+ with st.div(className="my-5"):
+ render_addon_section(user, selected_payment_provider)
if user.subscription and user.subscription.payment_provider:
if user.subscription.payment_provider == PaymentProvider.STRIPE:
with st.div(className="my-5"):
render_auto_recharge_section(user)
- with st.div(className="my-5"):
- render_addon_section(user)
with st.div(className="my-5"):
render_payment_information(user)
@@ -115,7 +122,7 @@ def render_credit_balance(user: AppUser):
)
-def render_all_plans(user: AppUser):
+def render_all_plans(user: AppUser) -> PaymentProvider:
current_plan = (
PricingPlan.from_sub(user.subscription)
if user.subscription
@@ -126,11 +133,11 @@ def render_all_plans(user: AppUser):
st.write("## All Plans")
plans_div = st.div(className="mb-1")
- if user.subscription:
- payment_provider = None
+ if user.subscription and user.subscription.payment_provider:
+ selected_payment_provider = None
else:
with st.div():
- payment_provider = PaymentProvider[
+ selected_payment_provider = PaymentProvider[
payment_provider_radio() or PaymentProvider.STRIPE.name
]
@@ -144,7 +151,9 @@ def _render_plan(plan: PricingPlan):
className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}"
):
_render_plan_details(plan)
- _render_plan_action_button(user, plan, current_plan, payment_provider)
+ _render_plan_action_button(
+ user, plan, current_plan, selected_payment_provider
+ )
with plans_div:
grid_layout(4, all_plans, _render_plan, separator=False)
@@ -152,6 +161,8 @@ def _render_plan(plan: PricingPlan):
with st.div(className="my-2 d-flex justify-content-center"):
st.caption(f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**")
+ return selected_payment_provider
+
def _render_plan_details(plan: PricingPlan):
with st.div(className="flex-grow-1"):
@@ -188,33 +199,47 @@ def _render_plan_action_button(
className=btn_classes + " btn btn-theme btn-primary",
):
st.html("Contact Us")
- elif current_plan is not PricingPlan.ENTERPRISE:
- update_subscription_button(
- user=user,
- plan=plan,
- current_plan=current_plan,
- className=btn_classes,
- payment_provider=payment_provider,
- )
+ elif user.subscription and not user.subscription.payment_provider:
+ # don't show upgrade/downgrade buttons for enterprise customers
+ # assumption: anyone without a payment provider attached is admin/enterprise
+ return
+ else:
+ if plan.credits > current_plan.credits:
+ label, btn_type = ("Upgrade", "primary")
+ else:
+ label, btn_type = ("Downgrade", "secondary")
+
+ if user.subscription and user.subscription.payment_provider:
+ # subscription exists, show upgrade/downgrade button
+ _render_update_subscription_button(
+ label,
+ user=user,
+ current_plan=current_plan,
+ plan=plan,
+ className=f"{btn_classes} btn btn-theme btn-{btn_type}",
+ )
+ else:
+ assert payment_provider is not None # for sanity
+ _render_create_subscription_button(
+ label,
+ btn_type=btn_type,
+ user=user,
+ plan=plan,
+ payment_provider=payment_provider,
+ )
-def update_subscription_button(
+def _render_create_subscription_button(
+ label: PlanActionLabel,
*,
+ btn_type: str,
user: AppUser,
- current_plan: PricingPlan,
plan: PricingPlan,
- className: str = "",
- payment_provider: PaymentProvider | None = None,
+ payment_provider: PaymentProvider,
):
- if plan.credits > current_plan.credits:
- label, btn_type = ("Upgrade", "primary")
- else:
- label, btn_type = ("Downgrade", "secondary")
- className += f" btn btn-theme btn-{btn_type}"
-
- key = f"change-sub-{plan.key}"
match payment_provider:
case PaymentProvider.STRIPE:
+ key = f"stripe-sub-{plan.key}"
render_stripe_subscription_button(
user=user,
label=label,
@@ -224,7 +249,19 @@ def update_subscription_button(
)
case PaymentProvider.PAYPAL:
render_paypal_subscription_button(plan=plan)
- case _ if label == "Downgrade":
+
+
+def _render_update_subscription_button(
+ label: PlanActionLabel,
+ *,
+ user: AppUser,
+ current_plan: PricingPlan,
+ plan: PricingPlan,
+ className: str = "",
+):
+ key = f"change-sub-{plan.key}"
+ match label:
+ case "Downgrade":
downgrade_modal = Modal(
"Confirm downgrade",
key=f"downgrade-plan-modal-{plan.key}",
@@ -260,7 +297,12 @@ def update_subscription_button(
downgrade_modal.close()
case _:
if st.button(label, className=className, key=key):
- change_subscription(user, plan)
+ change_subscription(
+ user,
+ plan,
+ # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time
+ billing_cycle_anchor="now",
+ )
def fmt_price(plan: PricingPlan) -> str:
@@ -270,7 +312,7 @@ def fmt_price(plan: PricingPlan) -> str:
return "Free"
-def change_subscription(user: AppUser, new_plan: PricingPlan):
+def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs):
from routers.account import account_route
from routers.account import payment_processing_route
@@ -301,9 +343,7 @@ def change_subscription(user: AppUser, new_plan: PricingPlan):
metadata={
settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: new_plan.key,
},
- # charge the full new amount today, without prorations
- # see: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time
- billing_cycle_anchor="now",
+ **kwargs,
proration_behavior="none",
)
raise RedirectException(
@@ -337,44 +377,57 @@ def payment_provider_radio(**props) -> str | None:
)
-def render_addon_section(user: AppUser):
- assert user.subscription
-
- st.write("# Purchase More Credits")
+def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider):
+ if user.subscription:
+ st.write("# Purchase More Credits")
+ else:
+ st.write("# Purchase Credits")
st.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
- provider = PaymentProvider(user.subscription.payment_provider)
+ if user.subscription:
+ provider = PaymentProvider(user.subscription.payment_provider)
+ else:
+ provider = selected_payment_provider
match provider:
case PaymentProvider.STRIPE:
- for amount in settings.ADDON_AMOUNT_CHOICES:
- render_stripe_addon_button(amount, user=user)
+ render_stripe_addon_buttons(user)
case PaymentProvider.PAYPAL:
- for amount in settings.ADDON_AMOUNT_CHOICES:
- render_paypal_addon_button(amount)
- st.div(
- id="paypal-addon-buttons",
- className="mt-2",
- style={"width": "fit-content"},
- )
- st.div(id="paypal-result-message")
+ render_paypal_addon_buttons()
-def render_paypal_addon_button(amount: int):
- st.html(
- f"""
-
- """
+def render_paypal_addon_buttons():
+ selected_amt = st.horizontal_radio(
+ "",
+ settings.ADDON_AMOUNT_CHOICES,
+ format_func=lambda amt: f"${amt:,}",
+ checked_by_default=False,
)
+ if selected_amt:
+ st.js(
+ f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})"
+ )
+ st.div(
+ id="paypal-addon-buttons",
+ className="mt-2",
+ style={"width": "fit-content"},
+ )
+ st.div(id="paypal-result-message")
+
+def render_stripe_addon_buttons(user: AppUser):
+ for dollat_amt in settings.ADDON_AMOUNT_CHOICES:
+ render_stripe_addon_button(dollat_amt, user)
-def render_stripe_addon_button(amount: int, user: AppUser):
- confirm_purchase_modal = Modal("Confirm Purchase", key=f"confirm-purchase-{amount}")
- if st.button(f"${amount:,}", type="primary"):
- confirm_purchase_modal.open()
+
+def render_stripe_addon_button(dollat_amt: int, user: AppUser):
+ confirm_purchase_modal = Modal(
+ "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}"
+ )
+ if st.button(f"${dollat_amt:,}", type="primary"):
+ if user.subscription:
+ confirm_purchase_modal.open()
+ else:
+ stripe_addon_checkout_redirect(user, dollat_amt)
if not confirm_purchase_modal.is_open():
return
@@ -382,7 +435,7 @@ def render_stripe_addon_button(amount: int, user: AppUser):
st.write(
f"""
Please confirm your purchase:
- **{amount * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${amount}**.
+ **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**.
""",
className="py-4 d-block text-center",
)
@@ -390,7 +443,7 @@ def render_stripe_addon_button(amount: int, user: AppUser):
if st.session_state.get("--confirm-purchase"):
success = st.run_in_thread(
user.subscription.stripe_attempt_addon_purchase,
- args=[amount],
+ args=[dollat_amt],
placeholder="Processing payment...",
)
if success is None:
@@ -407,6 +460,27 @@ def render_stripe_addon_button(amount: int, user: AppUser):
st.button("Buy", type="primary", key="--confirm-purchase")
+def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int):
+ from routers.account import account_route
+ from routers.account import payment_processing_route
+
+ line_item = available_subscriptions["addon"]["stripe"].copy()
+ line_item["quantity"] = dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR
+ checkout_session = stripe.checkout.Session.create(
+ line_items=[line_item],
+ mode="payment",
+ success_url=get_route_url(payment_processing_route),
+ cancel_url=get_route_url(account_route),
+ customer=user.get_or_create_stripe_customer(),
+ invoice_creation={"enabled": True},
+ allow_promotion_codes=True,
+ saved_payment_method_options={
+ "payment_method_save": "enabled",
+ },
+ )
+ raise RedirectException(checkout_session.url, status_code=303)
+
+
def render_stripe_subscription_button(
*,
label: str,
@@ -419,16 +493,19 @@ def render_stripe_subscription_button(
st.write("Stripe subscription not available")
return
- if st.button(label, type=btn_type, key=key):
- create_stripe_checkout_session(user=user, plan=plan)
+ # IMPORTANT: key=... is needed here to maintain uniqueness
+ # of buttons with the same label. otherwise, all buttons
+ # will be the same to the server
+ if st.button(label, key=key, type=btn_type):
+ stripe_subscription_checkout_redirect(user=user, plan=plan)
-def create_stripe_checkout_session(user: AppUser, plan: PricingPlan):
+def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan):
from routers.account import account_route
from routers.account import payment_processing_route
- if user.subscription and user.subscription.plan == plan.db_value:
- # already subscribed to the same plan
+ if user.subscription:
+ # already subscribed to some plan
return
metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key}
@@ -543,16 +620,17 @@ def render_billing_history(user: AppUser, limit: int = 50):
st.write("## Billing History", className="d-block")
st.table(
pd.DataFrame.from_records(
- columns=[""] * 3,
- data=[
- [
- txn.created_at.strftime("%m/%d/%Y"),
- txn.note(),
- f"${txn.charged_amount / 100:,.2f}",
- ]
+ [
+ {
+ "Date": txn.created_at.strftime("%m/%d/%Y"),
+ "Description": txn.reason_note(),
+ "Amount": f"-${txn.charged_amount / 100:,.2f}",
+ "Credits": f"+{txn.amount:,}",
+ "Balance": f"{txn.end_balance:,}",
+ }
for txn in txns[:limit]
- ],
- )
+ ]
+ ),
)
if txns.count() > limit:
st.caption(f"Showing only the most recent {limit} transactions.")
diff --git a/daras_ai_v2/exceptions.py b/daras_ai_v2/exceptions.py
index 78e3f053f..02acc4ec6 100644
--- a/daras_ai_v2/exceptions.py
+++ b/daras_ai_v2/exceptions.py
@@ -3,8 +3,16 @@
import typing
import requests
+from furl import furl
from loguru import logger
from requests import HTTPError
+from starlette.status import HTTP_402_PAYMENT_REQUIRED
+
+from daras_ai_v2 import settings
+
+if typing.TYPE_CHECKING:
+ from bots.models import SavedRun
+ from bots.models import AppUser
def raise_for_status(resp: requests.Response, is_user_url: bool = False):
@@ -47,9 +55,12 @@ def _response_preview(resp: requests.Response) -> bytes:
class UserError(Exception):
- def __init__(self, message: str, sentry_level: str = "info"):
+ def __init__(
+ self, message: str, sentry_level: str = "info", status_code: int = None
+ ):
self.message = message
self.sentry_level = sentry_level
+ self.status_code = status_code
super().__init__(message)
@@ -57,6 +68,41 @@ class GPUError(UserError):
pass
+class InsufficientCredits(UserError):
+ def __init__(self, user: "AppUser", sr: "SavedRun"):
+ from daras_ai_v2.base import SUBMIT_AFTER_LOGIN_Q
+
+ account_url = furl(settings.APP_BASE_URL) / "account/"
+ if user.is_anonymous:
+ account_url.query.params["next"] = sr.get_app_url(
+ query_params={SUBMIT_AFTER_LOGIN_Q: "1"},
+ )
+ # language=HTML
+ message = f"""
+
+Doh! Please login to run more Gooey.AI workflows. +
+ +You’ll receive {settings.LOGIN_USER_FREE_CREDITS} Credits when you sign up via your phone #, Google, Apple or GitHub account +and can purchase more for $1/100 Credits. +""" + else: + # language=HTML + message = f""" ++Doh! You’re out of Gooey.AI credits. +
+ ++Please buy more to run more workflows. +
+ +We’re always on discord if you’ve got any questions. +""" + + super().__init__(message, status_code=HTTP_402_PAYMENT_REQUIRED) + + FFMPEG_ERR_MSG = ( "Unsupported File Format\n\n" "We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. " diff --git a/daras_ai_v2/paypal.py b/daras_ai_v2/paypal.py index 4c5229b98..eae5042d4 100644 --- a/daras_ai_v2/paypal.py +++ b/daras_ai_v2/paypal.py @@ -155,6 +155,8 @@ class Subscription(PaypalResource): billing_info: BillingInfo | None def cancel(self, *, reason: str = "cancellation_requested") -> None: + if self.status in ["CANCELLED", "EXPIRED"]: + return r = requests.post( str(self.get_resource_url() / "cancel"), headers=get_default_headers(), diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 58878a9af..9b48a1c2e 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -263,7 +263,7 @@ ADMIN_EMAILS = config("ADMIN_EMAILS", cast=Csv(), default="") SUPPORT_EMAIL = "Gooey.AI Support- Your {{ title }} Gooey.AI run completed in {{ run_time_sec }} seconds. -
-- View output here: {{ app_url }} -
-
- Your prompt: {{ prompt }}
-
Your {{ recipe_title }} Gooey.AI run completed in {{ run_time_sec }} seconds.
+View output here: {{ app_url }}
-- We can’t wait to see what you build with Gooey! -
+{% if prompt %} +Your prompt: {{ prompt }}
We can’t wait to see what you build with Gooey!
- Cheers,
-
+ Cheers,
The Gooey.AI Team