Skip to content

Commit

Permalink
Optimize startup time for Celery tasks
Browse files Browse the repository at this point in the history
Report the complete run time of the Celery task to the frontend (measured from task start), not just the time spent inside the individual steps
Fix Sentry URLs
  • Loading branch information
devxpy committed Jul 7, 2024
1 parent fd80e78 commit 69b4a81
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 99 deletions.
130 changes: 54 additions & 76 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,114 +26,93 @@
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
from gooey_ui.state import set_query_params
from gooeysite.bg_db_conn import db_middleware, next_db_safe
from gooeysite.bg_db_conn import db_middleware

DEFAULT_RUN_STATUS = "Running..."


@app.task
def gui_runner(
*,
page_cls: typing.Type[BasePage],
user_id: str,
user_id: int,
run_id: str,
uid: str,
channel: str,
):
def event_processor(event, hint):
event["request"] = {
"method": "POST",
"url": page.app_url(query_params=st.get_query_params()),
"data": st.session_state,
}
return event

page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id)))
page.setup_sentry(event_processor=event_processor)
sr = page.run_doc_sr(run_id, uid)
st.set_session_state(sr.to_dict())
set_query_params(dict(run_id=run_id, uid=uid))

run_time = 0
yield_val = None
start_time = time()
error_msg = None

@db_middleware
def save(done=False):
def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False):
if isinstance(yield_val, tuple):
run_status, extra_output = yield_val
else:
run_status = yield_val
extra_output = {}

if done:
# clear run status
run_status = None
else:
# set run status to the yield value of generator
run_status = yield_val or "Running..."
if isinstance(run_status, tuple):
run_status, extra_output = run_status
else:
extra_output = {}
# set run status and run time
status = {
StateKeys.run_time: run_time,
StateKeys.error_msg: error_msg,
StateKeys.run_status: run_status,
}
run_status = run_status or DEFAULT_RUN_STATUS

output = (
status
|
# extract outputs from local state
# extract status of the run
{
StateKeys.error_msg: error_msg,
StateKeys.run_time: time() - start_time,
StateKeys.run_status: run_status,
}
# extract outputs from local state
| {
k: v
for k, v in st.session_state.items()
if k in page.ResponseModel.__fields__
}
# add extra outputs from the run
| extra_output
)

# send outputs to ui
realtime_push(channel, output)
# save to db
page.dump_state_to_sr(st.session_state | output, sr)

user = AppUser.objects.get(id=user_id)
page = page_cls(request=SimpleNamespace(user=user))
page.setup_sentry()
sr = page.run_doc_sr(run_id, uid)
st.set_session_state(sr.to_dict())
set_query_params(dict(run_id=run_id, uid=uid))

try:
gen = page.main(sr, st.session_state)
save()
while True:
# record time
start_time = time()
save_on_step()
for val in page.main(sr, st.session_state):
save_on_step(val)
# render errors nicely
except Exception as e:
if isinstance(e, HTTPException) and e.status_code == 402:
error_msg = page.generate_credit_error_message(run_id, uid)
try:
# advance the generator (to further progress of run())
yield_val = next_db_safe(gen)
# increment total time taken after every iteration
run_time += time() - start_time
continue
# run completed
except StopIteration:
run_time += time() - start_time
sr.transaction, sr.price = page.deduct_credits(st.session_state)
break
# render errors nicely
except Exception as e:
run_time += time() - start_time

if isinstance(e, HTTPException) and e.status_code == 402:
error_msg = page.generate_credit_error_message(run_id, uid)
try:
raise UserError(error_msg) from e
except UserError as e:
sentry_sdk.capture_exception(e, level=e.sentry_level)
break

if isinstance(e, UserError):
sentry_level = e.sentry_level
else:
sentry_level = "error"
traceback.print_exc()
sentry_sdk.capture_exception(e, level=sentry_level)
error_msg = err_msg_for_exc(e)
break
finally:
save()
raise UserError(error_msg) from e
except UserError as e:
sentry_sdk.capture_exception(e, level=e.sentry_level)
else:
if isinstance(e, UserError):
sentry_level = e.sentry_level
else:
sentry_level = "error"
traceback.print_exc()
sentry_sdk.capture_exception(e, level=sentry_level)
error_msg = err_msg_for_exc(e)
# run completed successfully, deduct credits
else:
sr.transaction, sr.price = page.deduct_credits(st.session_state)
finally:
save(done=True)
save_on_step(done=True)
if not sr.is_api_call:
send_email_on_completion(page, sr)
run_low_balance_email_check(uid)
run_low_balance_email_check(user)


def err_msg_for_exc(e: Exception):
Expand Down Expand Up @@ -162,11 +141,10 @@ def err_msg_for_exc(e: Exception):
return f"{type(e).__name__}: {e}"


def run_low_balance_email_check(uid: str):
def run_low_balance_email_check(user: AppUser):
# don't send email if feature is disabled
if not settings.LOW_BALANCE_EMAIL_ENABLED:
return
user = AppUser.objects.get(uid=uid)
# don't send email if user is not paying or has enough balance
if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS:
return
Expand Down
42 changes: 27 additions & 15 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,31 @@ def clean_query_params(cls, *, example_id, run_id, uid) -> dict:
def get_dynamic_meta_title(self) -> str | None:
return None

def setup_sentry(self, event_processor: typing.Callable = None):
def add_user_to_event(event, hint):
user = self.request and self.request.user
if not user:
return event
def setup_sentry(self):
with sentry_sdk.configure_scope() as scope:
scope.set_transaction_name(
"/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE
)
scope.add_event_processor(self.sentry_event_set_request)
scope.add_event_processor(self.sentry_event_set_user)

def sentry_event_set_request(self, event, hint):
request = event.setdefault("request", {})
request.setdefault("method", "POST")
request["data"] = st.session_state
if url := request.get("url"):
f = furl(url)
request["url"] = str(
furl(settings.APP_BASE_URL, path=f.pathstr, query=f.querystr).url
)
else:
request["url"] = self.app_url(
tab=self.tab, query_params=st.get_query_params()
)
return event

def sentry_event_set_user(self, event, hint):
if user := self.request and self.request.user:
event["user"] = {
"id": user.id,
"name": user.display_name,
Expand All @@ -282,20 +302,12 @@ def add_user_to_event(event, hint):
"is_anonymous",
"is_disabled",
"disable_safety_checker",
"disable_rate_limits",
"created_at",
]
},
}
return event

with sentry_sdk.configure_scope() as scope:
scope.set_extra("base_url", self.app_url())
scope.set_transaction_name(
"/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE
)
scope.add_event_processor(add_user_to_event)
if event_processor:
scope.add_event_processor(event_processor)
return event

def refresh_state(self):
_, run_id, uid = extract_query_params(gooey_get_query_params())
Expand Down
16 changes: 8 additions & 8 deletions tests/test_low_balance_email_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_dont_send_email_if_feature_is_disabled(transactional_db):
uid="test_user", is_paying=True, balance=0, is_anonymous=False
)
settings.LOW_BALANCE_EMAIL_ENABLED = False
run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)
assert not pytest_outbox


Expand All @@ -21,7 +21,7 @@ def test_dont_send_email_if_user_is_not_paying(transactional_db):
uid="test_user", is_paying=False, balance=0, is_anonymous=False
)
settings.LOW_BALANCE_EMAIL_ENABLED = True
run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)
assert not pytest_outbox


Expand All @@ -31,7 +31,7 @@ def test_dont_send_email_if_user_has_enough_balance(transactional_db):
)
settings.LOW_BALANCE_EMAIL_CREDITS = 100
settings.LOW_BALANCE_EMAIL_ENABLED = True
run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)
assert not pytest_outbox


Expand All @@ -46,7 +46,7 @@ def test_dont_send_email_if_user_has_been_emailed_recently(transactional_db):
settings.LOW_BALANCE_EMAIL_ENABLED = True
settings.LOW_BALANCE_EMAIL_DAYS = 1
settings.LOW_BALANCE_EMAIL_CREDITS = 100
run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)
assert not pytest_outbox


Expand Down Expand Up @@ -77,14 +77,14 @@ def test_send_email_if_user_has_been_email_recently_but_made_a_purchase(
settings.LOW_BALANCE_EMAIL_ENABLED = True
settings.LOW_BALANCE_EMAIL_DAYS = 1
settings.LOW_BALANCE_EMAIL_CREDITS = 100
run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)

assert len(pytest_outbox) == 1
assert " 22" in pytest_outbox[0]["html_body"]
assert " 78" in pytest_outbox[0]["html_body"]

pytest_outbox.clear()
run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)
assert not pytest_outbox


Expand Down Expand Up @@ -114,7 +114,7 @@ def test_send_email(transactional_db):
settings.LOW_BALANCE_EMAIL_DAYS = 1
settings.LOW_BALANCE_EMAIL_CREDITS = 100

run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)
assert len(pytest_outbox) == 1
body = pytest_outbox[0]["html_body"]
assert " 66" in body
Expand All @@ -123,5 +123,5 @@ def test_send_email(transactional_db):
assert " 100" not in body

pytest_outbox.clear()
run_low_balance_email_check(user.uid)
run_low_balance_email_check(user)
assert not pytest_outbox

0 comments on commit 69b4a81

Please sign in to comment.