Skip to content

Commit

Permalink
rename gui_runner -> runner_task
Browse files Browse the repository at this point in the history
optimize run time of runner_task: check for credits only once in entire call chain, separate post run stuff into a separate task
fix run complete email for empty prompts
  • Loading branch information
devxpy committed Jul 12, 2024
1 parent fa6d9e5 commit b5bc34b
Show file tree
Hide file tree
Showing 11 changed files with 280 additions and 201 deletions.
2 changes: 1 addition & 1 deletion bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
Workflow,
)
from bots.tasks import create_personal_channels_for_all_members
from celeryapp.tasks import gui_runner
from celeryapp.tasks import runner_task
from daras_ai_v2.fastapi_tricks import get_route_url
from gooeysite.custom_actions import export_to_excel, export_to_csv
from gooeysite.custom_filters import (
Expand Down
27 changes: 22 additions & 5 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,24 @@ class SavedRun(models.Model):

state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder)

error_msg = models.TextField(default="", blank=True)
error_msg = models.TextField(
default="",
blank=True,
help_text="The error message. If this is not set, the run is deemed successful.",
)
run_time = models.DurationField(default=datetime.timedelta, blank=True)
run_status = models.TextField(default="", blank=True)

error_code = models.IntegerField(
null=True,
default=None,
blank=True,
help_text="The HTTP status code of the error. If this is not set, 500 is assumed.",
)
error_type = models.TextField(
default="", blank=True, help_text="The exception type"
)

hidden = models.BooleanField(default=False)
is_flagged = models.BooleanField(default=False)

Expand Down Expand Up @@ -282,9 +296,12 @@ def __str__(self):
def parent_published_run(self) -> typing.Optional["PublishedRun"]:
return self.parent_version and self.parent_version.published_run

def get_app_url(self):
def get_app_url(self, query_params: dict = None):
return Workflow(self.workflow).page_cls.app_url(
example_id=self.example_id, run_id=self.run_id, uid=self.uid
example_id=self.example_id,
run_id=self.run_id,
uid=self.uid,
query_params=query_params,
)

def to_dict(self) -> dict:
Expand Down Expand Up @@ -1624,9 +1641,9 @@ def duplicate(
visibility=visibility,
)

def get_app_url(self):
def get_app_url(self, query_params: dict = None):
return Workflow(self.workflow).page_cls.app_url(
example_id=self.published_run_id
example_id=self.published_run_id, query_params=query_params
)

def add_version(
Expand Down
83 changes: 51 additions & 32 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import gooey_ui as st
from app_users.models import AppUser, AppUserTransaction
from bots.admin_links import change_obj_url
from bots.models import SavedRun, Platform
from bots.models import SavedRun, Platform, Workflow
from celeryapp.celeryconfig import app
from daras_ai.image_input import truncate_text_words
from daras_ai_v2 import settings
Expand All @@ -24,22 +24,24 @@
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
from gooey_ui.state import set_query_params
from gooeysite.bg_db_conn import db_middleware, next_db_safe
from payments.tasks import run_auto_recharge_gracefully
from gooeysite.bg_db_conn import db_middleware
from payments.auto_recharge import (
should_attempt_auto_recharge,
run_auto_recharge_gracefully,
)

DEFAULT_RUN_STATUS = "Running..."


@app.task
def gui_runner(
def runner_task(
*,
page_cls: typing.Type[BasePage],
user_id: int,
run_id: str,
uid: str,
channel: str,
):
) -> int:
start_time = time()
error_msg = None

Expand Down Expand Up @@ -89,36 +91,50 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
save_on_step()
for val in page.main(sr, st.session_state):
save_on_step(val)

# render errors nicely
except Exception as e:
if isinstance(e, HTTPException) and e.status_code == 402:
error_msg = page.generate_credit_error_message(run_id, uid)
try:
raise UserError(error_msg) from e
except UserError as e:
sentry_sdk.capture_exception(e, level=e.sentry_level)
if isinstance(e, UserError):
sentry_level = e.sentry_level
else:
if isinstance(e, UserError):
sentry_level = e.sentry_level
else:
sentry_level = "error"
traceback.print_exc()
sentry_sdk.capture_exception(e, level=sentry_level)
error_msg = err_msg_for_exc(e)
sentry_level = "error"
traceback.print_exc()
sentry_sdk.capture_exception(e, level=sentry_level)
error_msg = err_msg_for_exc(e)
sr.error_type = type(e).__qualname__
sr.error_code = getattr(e, "status_code", None)

# run completed successfully, deduct credits
else:
sr.transaction, sr.price = page.deduct_credits(st.session_state)

# save everything, mark run as completed
finally:
save_on_step(done=True)
if not sr.is_api_call:
send_email_on_completion(page, sr)

run_low_balance_email_check(user)
run_auto_recharge_gracefully(uid)
return sr.id


@app.task
def post_runner_tasks(saved_run_id: int):
sr = SavedRun.objects.get(id=saved_run_id)
user = AppUser.objects.get(uid=sr.uid)

if not sr.is_api_call:
send_email_on_completion(sr)

if should_attempt_auto_recharge(user):
run_auto_recharge_gracefully(user)

run_low_balance_email_check(user)


def err_msg_for_exc(e: Exception):
if isinstance(e, requests.HTTPError):
if isinstance(e, UserError):
return e.message
elif isinstance(e, HTTPException):
return f"(HTTP {e.status_code}) {e.detail})"
elif isinstance(e, requests.HTTPError):
response: requests.Response = e.response
try:
err_body = response.json()
Expand All @@ -135,10 +151,6 @@ def err_msg_for_exc(e: Exception):
return f"(GPU) {err_type}: {err_str}"
err_str = str(err_body)
return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}"
elif isinstance(e, HTTPException):
return f"(HTTP {e.status_code}) {e.detail})"
elif isinstance(e, UserError):
return e.message
else:
return f"{type(e).__name__}: {e}"

Expand Down Expand Up @@ -179,7 +191,7 @@ def run_low_balance_email_check(user: AppUser):
user.save(update_fields=["low_balance_email_sent_at"])


def send_email_on_completion(page: BasePage, sr: SavedRun):
def send_email_on_completion(sr: SavedRun):
run_time_sec = sr.run_time.total_seconds()
if (
run_time_sec <= settings.SEND_RUN_EMAIL_AFTER_SEC
Expand All @@ -191,9 +203,16 @@ def send_email_on_completion(page: BasePage, sr: SavedRun):
)
if not to_address:
return
prompt = (page.preview_input(sr.state) or "").strip()
title = (sr.state.get("__title") or page.title).strip()
subject = f"🌻 “{truncate_text_words(prompt, maxlen=50)}{title} is done"

workflow = Workflow(sr.workflow)
page_cls = workflow.page_cls
prompt = (page_cls.preview_input(sr.state) or "").strip().replace("\n", " ")
recipe_title = page_cls.get_recipe_title()

subject = (
f"🌻 “{truncate_text_words(prompt, maxlen=50) or 'Run'}{recipe_title} is done"
)

send_email_via_postmark(
from_address=settings.SUPPORT_EMAIL,
to_address=to_address,
Expand All @@ -202,7 +221,7 @@ def send_email_on_completion(page: BasePage, sr: SavedRun):
run_time_sec=round(run_time_sec),
app_url=sr.get_app_url(),
prompt=prompt,
title=title,
recipe_title=recipe_title,
),
message_stream="gooey-ai-workflows",
)
Expand Down
97 changes: 41 additions & 56 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from starlette.requests import Request

from daras_ai_v2.exceptions import UserError
import gooey_ui as st
from app_users.models import AppUser, AppUserTransaction
from bots.models import (
Expand All @@ -49,6 +48,7 @@
from daras_ai_v2.db import (
ANONYMOUS_USER_COOKIE,
)
from daras_ai_v2.exceptions import InsufficientCredits
from daras_ai_v2.fastapi_tricks import get_route_path
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.html_spinner_widget import html_spinner
Expand Down Expand Up @@ -83,12 +83,11 @@
from gooey_ui.components.modal import Modal
from gooey_ui.components.pills import pill
from gooey_ui.pubsub import realtime_pull
from routers.account import AccountTabs
from payments.auto_recharge import (
AutoRechargeException,
auto_recharge_user,
should_attempt_auto_recharge,
run_auto_recharge_gracefully,
)
from routers.account import AccountTabs
from routers.root import RecipeTabs

DEFAULT_META_IMG = (
Expand Down Expand Up @@ -1415,30 +1414,9 @@ def _render_help(self):
def render_usage_guide(self):
raise NotImplementedError

def run_with_auto_recharge(self, state: dict) -> typing.Iterator[str | None]:
if not self.check_credits() and should_attempt_auto_recharge(self.request.user):
yield "Low balance detected. Recharging..."
try:
auto_recharge_user(uid=self.request.user.uid)
except AutoRechargeException as e:
# raise this error only if another auto-recharge
# procedure didn't complete successfully
self.request.user.refresh_from_db()
if not self.check_credits():
raise UserError(str(e)) from e
else:
self.request.user.refresh_from_db()

if not self.check_credits():
example_id, run_id, uid = extract_query_params(gooey_get_query_params())
error_msg = self.generate_credit_error_message(
example_id=example_id, run_id=run_id, uid=uid
)
raise UserError(error_msg)

yield from self.run(state)

def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
yield from self.ensure_credits_and_auto_recharge(sr, state)

yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
Expand All @@ -1465,15 +1443,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
response = self.ResponseModel.construct()

# run the recipe
gen = self.run_v2(request, response)
while True:
try:
val = next(gen)
except StopIteration:
break
finally:
try:
for val in self.run_v2(request, response):
state.update(response.dict(exclude_unset=True))
yield val
yield val
finally:
state.update(response.dict(exclude_unset=True))

# validate the response if successful
self.ResponseModel.validate(response)
Expand Down Expand Up @@ -1634,15 +1609,7 @@ def on_submit(self):
st.session_state[StateKeys.error_msg] = e.detail.get("error", "")
return

if not self.check_credits() and not should_attempt_auto_recharge(
self.request.user
):
# insufficient balance for this run and auto-recharge isn't setup
sr.run_status = ""
sr.error_msg = self.generate_credit_error_message(sr.run_id, sr.uid)
sr.save(update_fields=["run_status", "error_msg"])
else:
self.call_runner_task(sr)
self.call_runner_task(sr)

raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))

Expand Down Expand Up @@ -1715,15 +1682,19 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun):
)

def call_runner_task(self, sr: SavedRun):
from celeryapp.tasks import gui_runner

return gui_runner.delay(
page_cls=self.__class__,
user_id=self.request.user.id,
run_id=sr.run_id,
uid=sr.uid,
channel=self.realtime_channel_name(sr.run_id, sr.uid),
from celeryapp.tasks import runner_task, post_runner_tasks

chain = (
runner_task.s(
page_cls=self.__class__,
user_id=self.request.user.id,
run_id=sr.run_id,
uid=sr.uid,
channel=self.realtime_channel_name(sr.run_id, sr.uid),
)
| post_runner_tasks.s()
)
return chain.apply_async()

@classmethod
def realtime_channel_name(cls, run_id, uid):
Expand Down Expand Up @@ -2099,13 +2070,27 @@ def run_as_api_tab(self):

manage_api_keys(self.request.user)

def check_credits(self) -> bool:
def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict):
if not settings.CREDITS_TO_DEDUCT_PER_RUN:
return True

return
assert self.request, "request must be set to check credits"
assert self.request.user, "request.user must be set to check credits"
return self.request.user.balance >= self.get_price_roundoff(st.session_state)

user = self.request.user
price = self.get_price_roundoff(state)

if user.balance >= price:
return

if should_attempt_auto_recharge(user):
yield "Low balance detected. Recharging..."
run_auto_recharge_gracefully(user)
user.refresh_from_db()

if user.balance >= price:
return

raise InsufficientCredits(self.request.user, sr)

def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]:
assert (
Expand Down
Loading

0 comments on commit b5bc34b

Please sign in to comment.