Skip to content

Commit

Permalink
Remove usage of SimpleNamespace for request handling in BasePage
Browse files Browse the repository at this point in the history
Move BasePage.run_user -> cached property current_sr_user
  • Loading branch information
devxpy committed Sep 6, 2024
1 parent 5524cc0 commit 76b2051
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 80 deletions.
7 changes: 2 additions & 5 deletions bots/admin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime
import json
from types import SimpleNamespace

import django.db.models
from django import forms
Expand Down Expand Up @@ -439,10 +438,8 @@ def rerun_tasks(self, request, queryset):
sr: SavedRun
for sr in queryset.all():
page = Workflow(sr.workflow).page_cls(
request=SimpleNamespace(
user=AppUser.objects.get(uid=sr.uid),
query_params=dict(run_id=sr.run_id, uid=sr.uid),
)
user=AppUser.objects.get(uid=sr.uid),
query_params=dict(run_id=sr.run_id, uid=sr.uid),
)
page.call_runner_task(sr, deduct_credits=False)
self.message_user(
Expand Down
6 changes: 1 addition & 5 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import traceback
import typing
from time import time
from types import SimpleNamespace

import gooey_gui as gui
import requests
Expand Down Expand Up @@ -92,10 +91,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
page.dump_state_to_sr(gui.session_state | output, sr)

page = page_cls(
request=SimpleNamespace(
user=AppUser.objects.get(id=user_id),
query_params=dict(run_id=run_id, uid=uid),
),
user=AppUser.objects.get(id=user_id), query_params=dict(run_id=run_id, uid=uid)
)
page.setup_sentry()
sr = page.current_sr
Expand Down
54 changes: 37 additions & 17 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from itertools import pairwise
from random import Random
from time import sleep
from types import SimpleNamespace

import gooey_gui as gui
import sentry_sdk
Expand All @@ -26,7 +25,6 @@
from sentry_sdk.tracing import (
TRANSACTION_SOURCE_ROUTE,
)
from starlette.requests import Request

from app_users.models import AppUser, AppUserTransaction
from bots.models import (
Expand Down Expand Up @@ -94,7 +92,6 @@
MAX_SEED = 4294967294
gooey_rng = Random()


SUBMIT_AFTER_LOGIN_Q = "submitafterlogin"


Expand All @@ -117,6 +114,12 @@ class StateKeys:
hidden = "__hidden"


class BasePageRequest:
user: AppUser | None
session: dict
query_params: dict


class BasePage:
title: str
workflow: Workflow
Expand Down Expand Up @@ -154,14 +157,20 @@ def __init__(
self,
*,
tab: RecipeTabs = RecipeTabs.run,
request: Request | SimpleNamespace | None = None,
run_user: AppUser | None = None,
request: BasePageRequest | None = None,
user: AppUser | None = None,
request_session: dict | None = None,
query_params: dict | None = None,
):
if request is None:
request = SimpleNamespace(user=None, query_params={})
self.tab = tab

if not request:
request = BasePageRequest()
request.user = user
request.session = request_session or {}
request.query_params = query_params or {}

self.request = request
self.run_user = run_user

@classmethod
def api_endpoint(cls) -> str:
Expand Down Expand Up @@ -349,7 +358,7 @@ def _render_header(self):

with gui.div(className="d-flex justify-content-between mt-4"):
with gui.div(className="d-lg-flex d-block align-items-center"):
if not tbreadcrumbs.has_breadcrumbs() and not self.run_user:
if not tbreadcrumbs.has_breadcrumbs() and not self.current_sr_user:
self._render_title(tbreadcrumbs.h1_title)

if tbreadcrumbs:
Expand All @@ -362,7 +371,7 @@ def _render_header(self):
if is_example:
author = pr.created_by
else:
author = self.run_user or sr.get_creator()
author = self.current_sr_user or sr.get_creator()
if not is_root_example:
self.render_author(author)

Expand All @@ -386,7 +395,7 @@ def _render_header(self):
self._render_published_run_save_buttons(sr=sr, pr=pr)
self._render_social_buttons(show_button_text=not show_save_buttons)

if tbreadcrumbs.has_breadcrumbs() or self.run_user:
if tbreadcrumbs.has_breadcrumbs() or self.current_sr_user:
# only render title here if the above row was not empty
self._render_title(tbreadcrumbs.h1_title)

Expand Down Expand Up @@ -810,7 +819,7 @@ def get_explore_image(self) -> str:
return meta_preview_url(img, fallback_img)

def _user_disabled_check(self):
if self.run_user and self.run_user.is_disabled:
if self.current_sr_user and self.current_sr_user.is_disabled:
msg = (
"This Gooey.AI account has been disabled for violating our [Terms of Service](/terms). "
"Contact us at [email protected] if you think this is a mistake."
Expand Down Expand Up @@ -1009,7 +1018,7 @@ def render_report_form(self):

send_reported_run_email(
user=self.request.user,
run_uid=str(self.run_user.uid),
run_uid=str(self.current_sr_user.uid),
url=self.current_app_url(),
recipe_name=self.title,
report_type=report_type,
Expand Down Expand Up @@ -1052,11 +1061,22 @@ def update_flag_for_run(self, is_flagged: bool):
sr.save(update_fields=["is_flagged"])
gui.session_state["is_flagged"] = is_flagged

@property
@cached_property
def current_sr_user(self) -> AppUser | None:
if not self.current_sr.uid:
return None
if self.request.user and self.request.user.uid == self.current_sr.uid:
return self.request.user
try:
return AppUser.objects.get(uid=self.current_sr.uid)
except AppUser.DoesNotExist:
return None

@cached_property
def current_sr(self) -> SavedRun:
return self.current_sr_pr[0]

@property
@cached_property
def current_pr(self) -> PublishedRun:
return self.current_sr_pr[1]

Expand Down Expand Up @@ -1571,7 +1591,7 @@ def create_new_run(
uid = self.request.user.uid
else:
uid = auth.create_user().uid
self.request.scope["user"] = AppUser.objects.create(
self.request.user = AppUser.objects.create(
uid=uid, is_anonymous=True, balance=settings.ANON_USER_FREE_CREDITS
)
self.request.session[ANONYMOUS_USER_COOKIE] = dict(uid=uid)
Expand Down Expand Up @@ -2138,7 +2158,7 @@ def is_current_user_paying(self) -> bool:
return bool(self.request.user and self.request.user.is_paying)

def is_current_user_owner(self) -> bool:
return bool(self.request.user and self.run_user == self.request.user)
return bool(self.request.user and self.current_sr_user == self.request.user)


def started_at_text(dt: datetime.datetime):
Expand Down
6 changes: 2 additions & 4 deletions recipes/LipsyncTTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
if not self.request.user.disable_safety_checker:
safety_checker(text=state["text_prompt"])

yield from TextToSpeechPage(request=self.request, run_user=self.run_user).run(
state
)
yield from TextToSpeechPage(request=self.request).run(state)
# IMP: Copy output of TextToSpeechPage "audio_url" to Lipsync as "input_audio"
state["input_audio"] = state["audio_url"]
yield from LipsyncPage(request=self.request, run_user=self.run_user).run(state)
yield from LipsyncPage(request=self.request).run(state)

def render_example(self, state: dict):
output_video = state.get("output_video")
Expand Down
8 changes: 2 additions & 6 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,9 +1015,7 @@ def run_v2(
tts_state = TextToSpeechPage.RequestModel.parse_obj(
{**gui.session_state, "text_prompt": text}
).dict()
yield from TextToSpeechPage(
request=self.request, run_user=self.run_user
).run(tts_state)
yield from TextToSpeechPage(request=self.request).run(tts_state)
response.output_audio.append(tts_state["audio_url"])

if not request.input_face:
Expand All @@ -1031,9 +1029,7 @@ def run_v2(
"selected_model": request.lipsync_model,
}
).dict()
yield from LipsyncPage(request=self.request, run_user=self.run_user).run(
lip_state
)
yield from LipsyncPage(request=self.request).run(lip_state)
response.output_video.append(lip_state["output_video"])

def get_tabs(self):
Expand Down
11 changes: 2 additions & 9 deletions routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os.path
import os.path
import typing
from types import SimpleNamespace

import gooey_gui as gui
from fastapi import Depends
Expand Down Expand Up @@ -262,11 +261,7 @@ def get_run_status(
user: AppUser = Depends(api_auth_header),
):
# init a new page for every request
self = page_cls(
request=SimpleNamespace(
user=user, query_params=dict(run_id=run_id, uid=user.uid)
)
)
self = page_cls(user=user, query_params=dict(run_id=run_id, uid=user.uid))
sr = self.current_sr
web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid)))
ret = {
Expand Down Expand Up @@ -344,9 +339,7 @@ def submit_api_call(
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
# init a new page for every request
query_params.setdefault("uid", current_user.uid)
page = page_cls(
request=SimpleNamespace(user=current_user, query_params=query_params)
)
page = page_cls(user=current_user, query_params=query_params)

# get saved state from db
state = page.current_sr_to_session_state()
Expand Down
29 changes: 10 additions & 19 deletions routers/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def explore_page(request: Request):
@gui.route(app, "/api/")
def api_docs_page(request: Request):
with page_wrapper(request):
_api_docs_page(request)
_api_docs_page()
return dict(
meta=raw_build_meta_tags(
url=get_og_url_path(request),
Expand All @@ -255,7 +255,7 @@ def api_docs_page(request: Request):
)


def _api_docs_page(request):
def _api_docs_page():
from daras_ai_v2.all_pages import all_api_pages

api_docs_url = str(furl(settings.API_BASE_URL) / "docs")
Expand Down Expand Up @@ -312,7 +312,7 @@ def _api_docs_page(request):
as_async = gui.checkbox("Run Async")
as_form_data = gui.checkbox("Upload Files via Form Data")

page = workflow.page_cls(request=request)
page = workflow.page_cls()
state = page.get_root_pr().saved_run.to_dict()
api_url, request_body = page.get_example_request(state, include_all=include_all)
response_body = page.get_example_response_body(
Expand Down Expand Up @@ -667,11 +667,13 @@ def render_recipe_page(
)
return RedirectResponse(str(new_url.set(origin=None)), status_code=301)

# this is because the code still expects example_id to be in the query params
request._query_params = dict(request.query_params) | dict(example_id=example_id)

page = page_cls(tab=tab, request=request)
page.run_user = get_run_user(request, page.current_sr.uid)
page = page_cls(
tab=tab,
user=request.user,
request_session=request.session,
# this is because the code still expects example_id to be in the query params
query_params=dict(request.query_params) | dict(example_id=example_id),
)

if not gui.session_state:
gui.session_state.update(page.current_sr_to_session_state())
Expand All @@ -692,17 +694,6 @@ def get_og_url_path(request) -> str:
)


def get_run_user(request: Request, uid: str) -> AppUser | None:
if not uid:
return
if request.user and request.user.uid == uid:
return request.user
try:
return AppUser.objects.get(uid=uid)
except AppUser.DoesNotExist:
pass


@contextmanager
def page_wrapper(request: Request, className=""):
context = {
Expand Down
21 changes: 6 additions & 15 deletions tests/test_pricing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from types import SimpleNamespace

import gooey_gui as gui
import pytest
from starlette.testclient import TestClient

Expand Down Expand Up @@ -49,10 +46,8 @@ def test_copilot_get_raw_price_round_up():
dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity,
)
copilot_page = VideoBotsPage(
request=SimpleNamespace(
user=user,
query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""),
),
user=user,
query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""),
)
assert (
copilot_page.get_price_roundoff(state=state)
Expand Down Expand Up @@ -114,10 +109,8 @@ def test_multiple_llm_sums_usage_cost():
)

llm_page = CompareLLMPage(
request=SimpleNamespace(
user=user,
query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""),
)
user=user,
query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""),
)
assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS)

Expand Down Expand Up @@ -163,10 +156,8 @@ def test_workflowmetadata_2x_multiplier():
metadata.save()

llm_page = CompareLLMPage(
request=SimpleNamespace(
user=user,
query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""),
)
user=user,
query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""),
)
assert (
llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2
Expand Down

0 comments on commit 76b2051

Please sign in to comment.