From 69b4a814ae030ad3d0101f0761a586f7d6b683df Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 8 Jul 2024 00:38:17 +0530 Subject: [PATCH] perf optimization on startup time for celery tasks report complete run time of celery task to frontend, not just the steps fix sentry urls --- celeryapp/tasks.py | 130 +++++++++++--------------- daras_ai_v2/base.py | 42 ++++++--- tests/test_low_balance_email_check.py | 16 ++-- 3 files changed, 89 insertions(+), 99 deletions(-) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 240ade4b3..b2a7b4327 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -26,114 +26,93 @@ from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push from gooey_ui.state import set_query_params -from gooeysite.bg_db_conn import db_middleware, next_db_safe +from gooeysite.bg_db_conn import db_middleware + +DEFAULT_RUN_STATUS = "Running..." @app.task def gui_runner( *, page_cls: typing.Type[BasePage], - user_id: str, + user_id: int, run_id: str, uid: str, channel: str, ): - def event_processor(event, hint): - event["request"] = { - "method": "POST", - "url": page.app_url(query_params=st.get_query_params()), - "data": st.session_state, - } - return event - - page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id))) - page.setup_sentry(event_processor=event_processor) - sr = page.run_doc_sr(run_id, uid) - st.set_session_state(sr.to_dict()) - set_query_params(dict(run_id=run_id, uid=uid)) - - run_time = 0 - yield_val = None + start_time = time() error_msg = None @db_middleware - def save(done=False): + def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False): + if isinstance(yield_val, tuple): + run_status, extra_output = yield_val + else: + run_status = yield_val + extra_output = {} + if done: - # clear run status run_status = None else: - # set run status to the yield value of generator - run_status = yield_val or "Running..." - if isinstance(run_status, tuple): - run_status, extra_output = run_status - else: - extra_output = {} - # set run status and run time - status = { - StateKeys.run_time: run_time, - StateKeys.error_msg: error_msg, - StateKeys.run_status: run_status, - } + run_status = run_status or DEFAULT_RUN_STATUS + output = ( - status - | - # extract outputs from local state + # extract status of the run { + StateKeys.error_msg: error_msg, + StateKeys.run_time: time() - start_time, + StateKeys.run_status: run_status, + } + # extract outputs from local state + | { k: v for k, v in st.session_state.items() if k in page.ResponseModel.__fields__ } + # add extra outputs from the run | extra_output ) + # send outputs to ui realtime_push(channel, output) # save to db page.dump_state_to_sr(st.session_state | output, sr) + user = AppUser.objects.get(id=user_id) + page = page_cls(request=SimpleNamespace(user=user)) + page.setup_sentry() + sr = page.run_doc_sr(run_id, uid) + st.set_session_state(sr.to_dict()) + set_query_params(dict(run_id=run_id, uid=uid)) + try: - gen = page.main(sr, st.session_state) - save() - while True: - # record time - start_time = time() + save_on_step() + for val in page.main(sr, st.session_state): + save_on_step(val) + # render errors nicely + except Exception as e: + if isinstance(e, HTTPException) and e.status_code == 402: + error_msg = page.generate_credit_error_message(run_id, uid) try: - # advance the generator (to further progress of run()) - yield_val = next_db_safe(gen) - # increment total time taken after every iteration - run_time += time() - start_time - continue - # run completed - except StopIteration: - run_time += time() - start_time - sr.transaction, sr.price = page.deduct_credits(st.session_state) - break - # render errors nicely - except Exception as e: - run_time += time() - start_time - - if isinstance(e, HTTPException) and e.status_code == 402: - error_msg = page.generate_credit_error_message(run_id, uid) - try: - raise UserError(error_msg) from e - except UserError as e: - sentry_sdk.capture_exception(e, level=e.sentry_level) - break - - if isinstance(e, UserError): - sentry_level = e.sentry_level - else: - sentry_level = "error" - traceback.print_exc() - sentry_sdk.capture_exception(e, level=sentry_level) - error_msg = err_msg_for_exc(e) - break - finally: - save() + raise UserError(error_msg) from e + except UserError as e: + sentry_sdk.capture_exception(e, level=e.sentry_level) + else: + if isinstance(e, UserError): + sentry_level = e.sentry_level + else: + sentry_level = "error" + traceback.print_exc() + sentry_sdk.capture_exception(e, level=sentry_level) + error_msg = err_msg_for_exc(e) + # run completed successfully, deduct credits + else: + sr.transaction, sr.price = page.deduct_credits(st.session_state) finally: - save(done=True) + save_on_step(done=True) if not sr.is_api_call: send_email_on_completion(page, sr) - run_low_balance_email_check(uid) + run_low_balance_email_check(user) def err_msg_for_exc(e: Exception): @@ -162,11 +141,10 @@ def err_msg_for_exc(e: Exception): return f"{type(e).__name__}: {e}" -def run_low_balance_email_check(uid: str): +def run_low_balance_email_check(user: AppUser): # don't send email if feature is disabled if not settings.LOW_BALANCE_EMAIL_ENABLED: return - user = AppUser.objects.get(uid=uid) # don't send email if user is not paying or has enough balance if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS: return diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index fba67cbc0..b1324fb68 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -262,11 +262,31 @@ def clean_query_params(cls, *, example_id, run_id, uid) -> dict: def get_dynamic_meta_title(self) -> str | None: return None - def setup_sentry(self, event_processor: typing.Callable = None): - def add_user_to_event(event, hint): - user = self.request and self.request.user - if not user: - return event + def setup_sentry(self): + with sentry_sdk.configure_scope() as scope: + scope.set_transaction_name( + "/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE + ) + scope.add_event_processor(self.sentry_event_set_request) + scope.add_event_processor(self.sentry_event_set_user) + + def sentry_event_set_request(self, event, hint): + request = event.setdefault("request", {}) + request.setdefault("method", "POST") + request["data"] = st.session_state + if url := request.get("url"): + f = furl(url) + request["url"] = str( + furl(settings.APP_BASE_URL, path=f.pathstr, query=f.querystr).url + ) + else: + request["url"] = self.app_url( + tab=self.tab, query_params=st.get_query_params() + ) + return event + + def sentry_event_set_user(self, event, hint): + if user := self.request and self.request.user: event["user"] = { "id": user.id, "name": user.display_name, @@ -282,20 +302,12 @@ def add_user_to_event(event, hint): "is_anonymous", "is_disabled", "disable_safety_checker", + "disable_rate_limits", "created_at", ] }, } - return event - - with sentry_sdk.configure_scope() as scope: - scope.set_extra("base_url", self.app_url()) - scope.set_transaction_name( - "/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE - ) - scope.add_event_processor(add_user_to_event) - if event_processor: - scope.add_event_processor(event_processor) + return event def refresh_state(self): _, run_id, uid = extract_query_params(gooey_get_query_params()) diff --git a/tests/test_low_balance_email_check.py b/tests/test_low_balance_email_check.py index 66203a19b..dc9c74148 100644 --- a/tests/test_low_balance_email_check.py +++ b/tests/test_low_balance_email_check.py @@ -12,7 +12,7 @@ def test_dont_send_email_if_feature_is_disabled(transactional_db): uid="test_user", is_paying=True, balance=0, is_anonymous=False ) settings.LOW_BALANCE_EMAIL_ENABLED = False - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert not pytest_outbox @@ -21,7 +21,7 @@ def test_dont_send_email_if_user_is_not_paying(transactional_db): uid="test_user", is_paying=False, balance=0, is_anonymous=False ) settings.LOW_BALANCE_EMAIL_ENABLED = True - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert not pytest_outbox @@ -31,7 +31,7 @@ def test_dont_send_email_if_user_has_enough_balance(transactional_db): ) settings.LOW_BALANCE_EMAIL_CREDITS = 100 settings.LOW_BALANCE_EMAIL_ENABLED = True - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert not pytest_outbox @@ -46,7 +46,7 @@ def test_dont_send_email_if_user_has_been_emailed_recently(transactional_db): settings.LOW_BALANCE_EMAIL_ENABLED = True settings.LOW_BALANCE_EMAIL_DAYS = 1 settings.LOW_BALANCE_EMAIL_CREDITS = 100 - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert not pytest_outbox @@ -77,14 +77,14 @@ def test_send_email_if_user_has_been_email_recently_but_made_a_purchase( settings.LOW_BALANCE_EMAIL_ENABLED = True settings.LOW_BALANCE_EMAIL_DAYS = 1 settings.LOW_BALANCE_EMAIL_CREDITS = 100 - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert len(pytest_outbox) == 1 assert " 22" in pytest_outbox[0]["html_body"] assert " 78" in pytest_outbox[0]["html_body"] pytest_outbox.clear() - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert not pytest_outbox @@ -114,7 +114,7 @@ def test_send_email(transactional_db): settings.LOW_BALANCE_EMAIL_DAYS = 1 settings.LOW_BALANCE_EMAIL_CREDITS = 100 - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert len(pytest_outbox) == 1 body = pytest_outbox[0]["html_body"] assert " 66" in body @@ -123,5 +123,5 @@ def test_send_email(transactional_db): assert " 100" not in body pytest_outbox.clear() - run_low_balance_email_check(user.uid) + run_low_balance_email_check(user) assert not pytest_outbox