From c649edc12c61b3b4e5d54266dd62f88c3565e687 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 6 Sep 2024 15:19:15 +0530 Subject: [PATCH] Remove usage of `SimpleNamespace` for request handling in BasePage Move BasePage.run_user -> cached property current_sr_user --- bots/admin.py | 7 ++---- celeryapp/tasks.py | 6 +---- daras_ai_v2/base.py | 54 +++++++++++++++++++++++++++++-------------- recipes/LipsyncTTS.py | 6 ++--- recipes/VideoBots.py | 8 ++----- routers/api.py | 11 ++------- routers/root.py | 29 ++++++++--------------- tests/test_pricing.py | 21 +++++------------ 8 files changed, 62 insertions(+), 80 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index b7e696dfd..8b52e233f 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -1,6 +1,5 @@ import datetime import json -from types import SimpleNamespace import django.db.models from django import forms @@ -439,10 +438,8 @@ def rerun_tasks(self, request, queryset): sr: SavedRun for sr in queryset.all(): page = Workflow(sr.workflow).page_cls( - request=SimpleNamespace( - user=AppUser.objects.get(uid=sr.uid), - query_params=dict(run_id=sr.run_id, uid=sr.uid), - ) + user=AppUser.objects.get(uid=sr.uid), + query_params=dict(run_id=sr.run_id, uid=sr.uid), ) page.call_runner_task(sr, deduct_credits=False) self.message_user( diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index b2e257365..3fed1442f 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -4,7 +4,6 @@ import traceback import typing from time import time -from types import SimpleNamespace import gooey_gui as gui import requests @@ -92,10 +91,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False page.dump_state_to_sr(gui.session_state | output, sr) page = page_cls( - request=SimpleNamespace( - user=AppUser.objects.get(id=user_id), - query_params=dict(run_id=run_id, uid=uid), - ), + user=AppUser.objects.get(id=user_id), query_params=dict(run_id=run_id, uid=uid) ) page.setup_sentry() sr = page.current_sr diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 90432553c..870debdb1 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -12,7 +12,6 @@ from itertools import pairwise from random import Random from time import sleep -from types import SimpleNamespace import gooey_gui as gui import sentry_sdk @@ -26,7 +25,6 @@ from sentry_sdk.tracing import ( TRANSACTION_SOURCE_ROUTE, ) -from starlette.requests import Request from app_users.models import AppUser, AppUserTransaction from bots.models import ( @@ -94,7 +92,6 @@ MAX_SEED = 4294967294 gooey_rng = Random() - SUBMIT_AFTER_LOGIN_Q = "submitafterlogin" @@ -117,6 +114,12 @@ class StateKeys: hidden = "__hidden" +class BasePageRequest: + user: AppUser | None + session: dict + query_params: dict + + class BasePage: title: str workflow: Workflow @@ -154,14 +157,20 @@ def __init__( self, *, tab: RecipeTabs = RecipeTabs.run, - request: Request | SimpleNamespace | None = None, - run_user: AppUser | None = None, + request: BasePageRequest | None = None, + user: AppUser | None = None, + request_session: dict | None = None, + query_params: dict | None = None, ): - if request is None: - request = SimpleNamespace(user=None, query_params={}) self.tab = tab + + if not request: + request = BasePageRequest() + request.user = user + request.session = request_session or {} + request.query_params = query_params or {} + self.request = request - self.run_user = run_user @classmethod def api_endpoint(cls) -> str: @@ -349,7 +358,7 @@ def _render_header(self): with gui.div(className="d-flex justify-content-between mt-4"): with gui.div(className="d-lg-flex d-block align-items-center"): - if not tbreadcrumbs.has_breadcrumbs() and not self.run_user: + if not tbreadcrumbs.has_breadcrumbs() and not self.current_sr_user: self._render_title(tbreadcrumbs.h1_title) if tbreadcrumbs: @@ -362,7 +371,7 @@ def _render_header(self): if is_example: author = pr.created_by else: - author = self.run_user or sr.get_creator() + author = self.current_sr_user or sr.get_creator() if not is_root_example: self.render_author(author) @@ -386,7 +395,7 @@ def _render_header(self): self._render_published_run_save_buttons(sr=sr, pr=pr) self._render_social_buttons(show_button_text=not show_save_buttons) - if tbreadcrumbs.has_breadcrumbs() or self.run_user: + if tbreadcrumbs.has_breadcrumbs() or self.current_sr_user: # only render title here if the above row was not empty self._render_title(tbreadcrumbs.h1_title) @@ -810,7 +819,7 @@ def get_explore_image(self) -> str: return meta_preview_url(img, fallback_img) def _user_disabled_check(self): - if self.run_user and self.run_user.is_disabled: + if self.current_sr_user and self.current_sr_user.is_disabled: msg = ( "This Gooey.AI account has been disabled for violating our [Terms of Service](/terms). " "Contact us at support@gooey.ai if you think this is a mistake." @@ -1009,7 +1018,7 @@ def render_report_form(self): send_reported_run_email( user=self.request.user, - run_uid=str(self.run_user.uid), + run_uid=str(self.current_sr_user.uid), url=self.current_app_url(), recipe_name=self.title, report_type=report_type, @@ -1052,11 +1061,22 @@ def update_flag_for_run(self, is_flagged: bool): sr.save(update_fields=["is_flagged"]) gui.session_state["is_flagged"] = is_flagged - @property + @cached_property + def current_sr_user(self) -> AppUser | None: + if not self.current_sr.uid: + return None + if self.request.user and self.request.user.uid == self.current_sr.uid: + return self.request.user + try: + return AppUser.objects.get(uid=self.current_sr.uid) + except AppUser.DoesNotExist: + return None + + @cached_property def current_sr(self) -> SavedRun: return self.current_sr_pr[0] - @property + @cached_property def current_pr(self) -> PublishedRun: return self.current_sr_pr[1] @@ -1571,7 +1591,7 @@ def create_new_run( uid = self.request.user.uid else: uid = auth.create_user().uid - self.request.scope["user"] = AppUser.objects.create( + self.request.user = AppUser.objects.create( uid=uid, is_anonymous=True, balance=settings.ANON_USER_FREE_CREDITS ) self.request.session[ANONYMOUS_USER_COOKIE] = dict(uid=uid) @@ -2138,7 +2158,7 @@ def is_current_user_paying(self) -> bool: return bool(self.request.user and self.request.user.is_paying) def is_current_user_owner(self) -> bool: - return bool(self.request.user and self.run_user == self.request.user) + return bool(self.request.user and self.current_sr_user == self.request.user) def started_at_text(dt: datetime.datetime): diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 995f5d43f..d557cd663 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -122,12 +122,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]: if not self.request.user.disable_safety_checker: safety_checker(text=state["text_prompt"]) - yield from TextToSpeechPage(request=self.request, run_user=self.run_user).run( - state - ) + yield from TextToSpeechPage(request=self.request).run(state) # IMP: Copy output of TextToSpeechPage "audio_url" to Lipsync as "input_audio" state["input_audio"] = state["audio_url"] - yield from LipsyncPage(request=self.request, run_user=self.run_user).run(state) + yield from LipsyncPage(request=self.request).run(state) def render_example(self, state: dict): output_video = state.get("output_video") diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 26b7e785e..77b266c26 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1015,9 +1015,7 @@ def run_v2( tts_state = TextToSpeechPage.RequestModel.parse_obj( {**gui.session_state, "text_prompt": text} ).dict() - yield from TextToSpeechPage( - request=self.request, run_user=self.run_user - ).run(tts_state) + yield from TextToSpeechPage(request=self.request).run(tts_state) response.output_audio.append(tts_state["audio_url"]) if not request.input_face: @@ -1031,9 +1029,7 @@ def run_v2( "selected_model": request.lipsync_model, } ).dict() - yield from LipsyncPage(request=self.request, run_user=self.run_user).run( - lip_state - ) + yield from LipsyncPage(request=self.request).run(lip_state) response.output_video.append(lip_state["output_video"]) def get_tabs(self): diff --git a/routers/api.py b/routers/api.py index 68ad4c2b6..73c5ec7b9 100644 --- a/routers/api.py +++ b/routers/api.py @@ -3,7 +3,6 @@ import os.path import os.path import typing -from types import SimpleNamespace import gooey_gui as gui from fastapi import Depends @@ -262,11 +261,7 @@ def get_run_status( user: AppUser = Depends(api_auth_header), ): # init a new page for every request - self = page_cls( - request=SimpleNamespace( - user=user, query_params=dict(run_id=run_id, uid=user.uid) - ) - ) + self = page_cls(user=user, query_params=dict(run_id=run_id, uid=user.uid)) sr = self.current_sr web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid))) ret = { @@ -344,9 +339,7 @@ def submit_api_call( ) -> tuple["celery.result.AsyncResult", "SavedRun"]: # init a new page for every request query_params.setdefault("uid", current_user.uid) - page = page_cls( - request=SimpleNamespace(user=current_user, query_params=query_params) - ) + page = page_cls(user=current_user, query_params=query_params) # get saved state from db state = page.current_sr_to_session_state() diff --git a/routers/root.py b/routers/root.py index 9f678416b..40884fd55 100644 --- a/routers/root.py +++ b/routers/root.py @@ -242,7 +242,7 @@ def explore_page(request: Request): @gui.route(app, "/api/") def api_docs_page(request: Request): with page_wrapper(request): - _api_docs_page(request) + _api_docs_page() return dict( meta=raw_build_meta_tags( url=get_og_url_path(request), @@ -255,7 +255,7 @@ def api_docs_page(request: Request): ) -def _api_docs_page(request): +def _api_docs_page(): from daras_ai_v2.all_pages import all_api_pages api_docs_url = str(furl(settings.API_BASE_URL) / "docs") @@ -312,7 +312,7 @@ def _api_docs_page(request): as_async = gui.checkbox("Run Async") as_form_data = gui.checkbox("Upload Files via Form Data") - page = workflow.page_cls(request=request) + page = workflow.page_cls() state = page.get_root_pr().saved_run.to_dict() api_url, request_body = page.get_example_request(state, include_all=include_all) response_body = page.get_example_response_body( @@ -667,11 +667,13 @@ def render_recipe_page( ) return RedirectResponse(str(new_url.set(origin=None)), status_code=301) - # this is because the code still expects example_id to be in the query params - request._query_params = dict(request.query_params) | dict(example_id=example_id) - - page = page_cls(tab=tab, request=request) - page.run_user = get_run_user(request, page.current_sr.uid) + page = page_cls( + tab=tab, + user=request.user, + request_session=request.session, + # this is because the code still expects example_id to be in the query params + query_params=dict(request.query_params) | dict(example_id=example_id), + ) if not gui.session_state: gui.session_state.update(page.current_sr_to_session_state()) @@ -692,17 +694,6 @@ def get_og_url_path(request) -> str: ) -def get_run_user(request: Request, uid: str) -> AppUser | None: - if not uid: - return - if request.user and request.user.uid == uid: - return request.user - try: - return AppUser.objects.get(uid=uid) - except AppUser.DoesNotExist: - pass - - @contextmanager def page_wrapper(request: Request, className=""): context = { diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 73e05e4a9..c78bf5376 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,6 +1,3 @@ -from types import SimpleNamespace - -import gooey_gui as gui import pytest from starlette.testclient import TestClient @@ -49,10 +46,8 @@ def test_copilot_get_raw_price_round_up(): dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) copilot_page = VideoBotsPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ), + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert ( copilot_page.get_price_roundoff(state=state) @@ -114,10 +109,8 @@ def test_multiple_llm_sums_usage_cost(): ) llm_page = CompareLLMPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ) + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -163,10 +156,8 @@ def test_workflowmetadata_2x_multiplier(): metadata.save() llm_page = CompareLLMPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ) + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2