From 99d644b7f1629de66e3c044f09e3ed6ce6d0d98d Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 29 Aug 2024 18:53:17 +0530 Subject: [PATCH 1/4] Refactor `BasePage` methods to consolidate `SavedRun` and `PublishedRun` retrieval logic into `get_sr_pr` Remove usage of global gui.get_query_params --- bots/admin.py | 5 +- bots/models.py | 12 +- celeryapp/tasks.py | 23 +- conftest.py | 4 +- daras_ai_v2/base.py | 344 ++++++++------------- daras_ai_v2/bot_integration_widgets.py | 10 +- daras_ai_v2/bots.py | 2 +- daras_ai_v2/doc_search_settings_widgets.py | 2 +- daras_ai_v2/meta_content.py | 5 +- daras_ai_v2/safety_checker.py | 2 +- daras_ai_v2/workflow_url_input.py | 4 +- explore.py | 2 +- recipes/DocSearch.py | 4 +- recipes/Functions.py | 8 +- recipes/GoogleGPT.py | 2 +- recipes/VideoBots.py | 11 +- recipes/VideoBotsStats.py | 7 +- routers/api.py | 21 +- routers/bots_api.py | 2 +- routers/root.py | 21 +- routers/twilio_api.py | 2 +- tests/test_apis.py | 4 +- tests/test_pricing.py | 26 +- url_shortener/models.py | 13 +- usage_costs/cost_utils.py | 13 +- 25 files changed, 239 insertions(+), 310 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index eaf200b04..b7e696dfd 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -439,7 +439,10 @@ def rerun_tasks(self, request, queryset): sr: SavedRun for sr in queryset.all(): page = Workflow(sr.workflow).page_cls( - request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid)) + request=SimpleNamespace( + user=AppUser.objects.get(uid=sr.uid), + query_params=dict(run_id=sr.run_id, uid=sr.uid), + ) ) page.call_runner_task(sr, deduct_credits=False) self.message_user( diff --git a/bots/models.py b/bots/models.py index e997e8f8a..3ace6bc33 100644 --- a/bots/models.py +++ b/bots/models.py @@ -127,16 +127,12 @@ def get_or_create_metadata(self) -> "WorkflowMetadata": workflow=self, create=lambda **kwargs: WorkflowMetadata.objects.create( **kwargs, - short_title=( - self.page_cls.get_root_published_run().title or self.page_cls.title - ), + short_title=(self.page_cls.get_root_pr().title or self.page_cls.title), default_image=self.page_cls.explore_image or "", - meta_title=( - self.page_cls.get_root_published_run().title or self.page_cls.title - ), + meta_title=(self.page_cls.get_root_pr().title or self.page_cls.title), meta_description=( self.page_cls().preview_description(state={}) - or self.page_cls.get_root_published_run().notes + or self.page_cls.get_root_pr().notes ), meta_image=self.page_cls.explore_image or "", ), @@ -389,7 +385,7 @@ def submit_api_call( ), ) - return result, page.run_doc_sr(run_id, uid) + return result, page.current_sr def get_creator(self) -> AppUser | None: if self.uid: diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index c3ae75b52..b2e257365 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -1,5 +1,6 @@ import datetime import html +import threading import traceback import typing from time import time @@ -31,6 +32,15 @@ DEFAULT_RUN_STATUS = "Running..." +threadlocal = threading.local() + + +def get_running_saved_run() -> SavedRun | None: + try: + return threadlocal.saved_run + except AttributeError: + return None + @app.task def runner_task( @@ -81,12 +91,16 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save to db page.dump_state_to_sr(gui.session_state | output, sr) - user = AppUser.objects.get(id=user_id) - page = page_cls(request=SimpleNamespace(user=user)) + page = page_cls( + request=SimpleNamespace( + user=AppUser.objects.get(id=user_id), + query_params=dict(run_id=run_id, uid=uid), + ), + ) page.setup_sentry() - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr + threadlocal.saved_run = sr gui.set_session_state(sr.to_dict() | (unsaved_state or {})) - gui.set_query_params(dict(run_id=run_id, uid=uid)) try: save_on_step() @@ -114,6 +128,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save everything, mark run as completed finally: save_on_step(done=True) + threadlocal.saved_run = None post_runner_tasks.delay(sr.id) diff --git a/conftest.py b/conftest.py index 57d5d5cca..539ea848e 100644 --- a/conftest.py +++ b/conftest.py @@ -64,10 +64,10 @@ def mock_celery_tasks(): def _mock_runner_task( *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs ): - sr = page_cls.run_doc_sr(run_id, uid) + sr = page_cls.get_sr_from_ids(run_id, uid) sr.set(sr.parent.to_dict()) sr.save() - channel = page_cls().realtime_channel_name(run_id, uid) + channel = page_cls.realtime_channel_name(run_id, uid) _mock_realtime_push(channel, sr.to_dict()) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 3ab715ea2..d1f3715c1 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -8,6 +8,7 @@ import uuid from copy import deepcopy, copy from enum import Enum +from functools import cached_property from itertools import pairwise from random import Random from time import sleep @@ -151,10 +152,13 @@ class RequestModel(BaseModel): def __init__( self, - tab: RecipeTabs = "", - request: Request | SimpleNamespace = None, - run_user: AppUser = None, + *, + tab: RecipeTabs = RecipeTabs.run, + request: Request | SimpleNamespace | None = None, + run_user: AppUser | None = None, ): + if request is None: + request = SimpleNamespace(user=None, query_params={}) self.tab = tab self.request = request self.run_user = run_user @@ -164,9 +168,8 @@ def __init__( def endpoint(cls) -> str: return f"/v2/{cls.slug_versions[0]}" - @classmethod def current_app_url( - cls, + self, tab: RecipeTabs = RecipeTabs.run, *, query_params: dict = None, @@ -174,8 +177,8 @@ def current_app_url( ) -> str: if query_params is None: query_params = {} - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.app_url( + example_id, run_id, uid = extract_query_params(self.request.query_params) + return self.app_url( tab=tab, example_id=example_id, run_id=run_id, @@ -209,7 +212,7 @@ def app_url( run_slug = None if example_id: try: - pr = cls.get_published_run(published_run_id=example_id) + pr = cls.get_pr_from_example_id(example_id=example_id) except PublishedRun.DoesNotExist: pr = None if pr and pr.title: @@ -225,11 +228,6 @@ def app_url( ) ) - @classmethod - def current_api_url(cls) -> furl | None: - pr = cls.get_current_published_run() - return cls.api_url(example_id=pr and pr.published_run_id) - @classmethod def api_url( cls, @@ -276,12 +274,12 @@ def sentry_event_set_request(self, event, hint): ) else: request["url"] = self.app_url( - tab=self.tab, query_params=gui.get_query_params() + tab=self.tab, query_params=dict(self.request.query_params) ) return event def sentry_event_set_user(self, event, hint): - if user := self.request and self.request.user: + if user := self.request.user: event["user"] = { "id": user.id, "name": user.display_name, @@ -305,7 +303,7 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - sr = self.get_current_sr() + sr = self.current_sr channel = self.realtime_channel_name(sr.run_id, sr.uid) output = gui.realtime_pull([channel])[0] if output: @@ -341,14 +339,11 @@ def render(self): self._render_header() def _render_header(self): - current_run = self.get_current_sr() - published_run = self.get_current_published_run() - is_example = published_run.saved_run == current_run - is_root_example = is_example and published_run.is_root() - tbreadcrumbs = get_title_breadcrumbs( - self, current_run, published_run, tab=self.tab - ) - can_save = self.can_user_save_run(current_run, published_run) + sr, pr = self.current_sr_pr + is_example = pr.saved_run == sr + is_root_example = is_example and pr.is_root() + tbreadcrumbs = get_title_breadcrumbs(self, sr, pr, tab=self.tab) + can_save = self.can_user_save_run(sr, pr) request_changed = self._has_request_changed() with gui.div(className="d-flex justify-content-between mt-4"): @@ -360,15 +355,13 @@ def _render_header(self): with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): render_breadcrumbs( tbreadcrumbs, - is_api_call=( - current_run.is_api_call and self.tab == RecipeTabs.run - ), + is_api_call=(sr.is_api_call and self.tab == RecipeTabs.run), ) if is_example: - author = published_run.created_by + author = pr.created_by else: - author = self.run_user or current_run.get_creator() + author = self.run_user or sr.get_creator() if not is_root_example: self.render_author(author) @@ -389,10 +382,7 @@ def _render_header(self): show_save_buttons = request_changed or can_save if show_save_buttons: - self._render_published_run_save_buttons( - current_run=current_run, - published_run=published_run, - ) + self._render_published_run_save_buttons(sr=sr, pr=pr) self._render_social_buttons(show_button_text=not show_save_buttons) if tbreadcrumbs.has_breadcrumbs() or self.run_user: @@ -401,15 +391,15 @@ def _render_header(self): if self.tab != RecipeTabs.run: return - if published_run and published_run.notes: - gui.write(published_run.notes, line_clamp=2) + if pr and pr.notes: + gui.write(pr.notes, line_clamp=2) elif is_root_example and self.tab != RecipeTabs.integrations: - gui.write(self.preview_description(current_run.to_dict()), line_clamp=2) + gui.write(self.preview_description(sr.to_dict()), line_clamp=2) def can_user_save_run( self, current_run: SavedRun, - published_run: PublishedRun | None, + published_run: PublishedRun, ) -> bool: return ( self.is_current_user_admin() @@ -425,13 +415,9 @@ def can_user_save_run( ) ) - def can_user_edit_published_run( - self, published_run: PublishedRun | None = None - ) -> bool: - published_run = published_run or self.get_current_published_run() + def can_user_edit_published_run(self, published_run: PublishedRun) -> bool: return self.is_current_user_admin() or bool( - published_run - and self.request + self.request and self.request.user and published_run.created_by_id and published_run.created_by_id == self.request.user.id @@ -460,13 +446,8 @@ def _render_social_buttons(self, show_button_text: bool = False): className="mb-0 ms-lg-2", ) - def _render_published_run_save_buttons( - self, - *, - current_run: SavedRun, - published_run: PublishedRun, - ): - can_edit = self.can_user_edit_published_run(published_run) + def _render_published_run_save_buttons(self, *, sr: SavedRun, pr: PublishedRun): + can_edit = self.can_user_edit_published_run(pr) with gui.div(className="d-flex justify-content-end"): gui.html( @@ -497,8 +478,8 @@ def _render_published_run_save_buttons( if options_modal.is_open(): with options_modal.container(style={"minWidth": "min(300px, 100vw)"}): self._render_options_modal( - current_run=current_run, - published_run=published_run, + current_run=sr, + published_run=pr, modal=options_modal, ) @@ -518,8 +499,8 @@ def _render_published_run_save_buttons( if publish_modal.is_open(): with publish_modal.container(style={"minWidth": "min(500px, 100vw)"}): self._render_publish_modal( - current_run=current_run, - published_run=published_run, + sr=sr, + pr=pr, modal=publish_modal, is_update_mode=can_edit, ) @@ -527,12 +508,12 @@ def _render_published_run_save_buttons( def _render_publish_modal( self, *, - current_run: SavedRun, - published_run: PublishedRun, + sr: SavedRun, + pr: PublishedRun, modal: gui.Modal, is_update_mode: bool = False, ): - if published_run.is_root() and self.is_current_user_admin(): + if pr.is_root() and self.is_current_user_admin(): with gui.div(className="text-danger"): gui.write( "###### You're about to update the root workflow as an admin. " @@ -564,7 +545,7 @@ def _render_publish_modal( "", options=options, format_func=options.__getitem__, - value=str(published_run.visibility), + value=str(pr.visibility), ) ) ) @@ -579,9 +560,9 @@ def _render_publish_modal( with gui.div(className="mt-4"): if is_update_mode: - title = published_run.title or self.title + title = pr.title or self.title else: - recipe_title = self.get_root_published_run().title or self.title + recipe_title = self.get_root_pr().title or self.title title = f"{self.request.user.first_name_possesive()} {recipe_title}" published_run_title = gui.text_input( "##### Title", @@ -591,11 +572,7 @@ def _render_publish_modal( published_run_notes = gui.text_area( "##### Notes", key="published_run_notes", - value=( - published_run.notes - or self.preview_description(gui.session_state) - or "" - ), + value=(pr.notes or self.preview_description(gui.session_state) or ""), ) with gui.div(className="mt-4 d-flex justify-content-center"): @@ -605,12 +582,12 @@ def _render_publish_modal( type="primary", ) - self._render_admin_options(current_run, published_run) + self._render_admin_options(sr, pr) if not pressed_save: return - is_root_published_run = is_update_mode and published_run.is_root() + is_root_published_run = is_update_mode and pr.is_root() if not is_root_published_run: try: self._validate_published_run_title(published_run_title) @@ -619,33 +596,31 @@ def _render_publish_modal( return if self._has_request_changed(): - current_run = self.on_submit() - if not current_run: + sr = self.on_submit() + if not sr: modal.close() if is_update_mode: updates = dict( - saved_run=current_run, + saved_run=sr, title=published_run_title.strip(), notes=published_run_notes.strip(), visibility=published_run_visibility, ) - if not self._has_published_run_changed( - published_run=published_run, **updates - ): + if not self._has_published_run_changed(published_run=pr, **updates): gui.error("No changes to publish", icon="⚠️") return - published_run.add_version(user=self.request.user, **updates) + pr.add_version(user=self.request.user, **updates) else: - published_run = self.create_published_run( + pr = self.create_published_run( published_run_id=get_random_doc_id(), - saved_run=current_run, + saved_run=sr, user=self.request.user, title=published_run_title.strip(), notes=published_run_notes.strip(), visibility=published_run_visibility, ) - raise gui.RedirectException(published_run.get_app_url()) + raise gui.RedirectException(pr.get_app_url()) def _validate_published_run_title(self, title: str): if slugify(title) in settings.DISALLOWED_TITLE_SLUGS: @@ -813,7 +788,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR className="text-danger", ) if gui.button("👌 Yes, Update the Root Workflow"): - root_run = self.get_root_published_run() + root_run = self.get_root_pr() root_run.add_version( user=self.request.user, title=published_run.title, @@ -825,7 +800,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR @classmethod def get_recipe_title(cls) -> str: - return cls.get_root_published_run().title or cls.title or cls.workflow.label + return cls.get_root_pr().title or cls.title or cls.workflow.label def get_explore_image(self) -> str: meta = self.workflow.get_or_create_metadata() @@ -853,7 +828,7 @@ def get_tabs(self): def render_selected_tab(self): match self.tab: case RecipeTabs.run: - if self.get_current_sr().retention_policy == RetentionPolicy.delete: + if self.current_sr.retention_policy == RetentionPolicy.delete: self.render_deleted_output() return @@ -884,15 +859,12 @@ def render_selected_tab(self): self._saved_tab() def _render_version_history(self): - published_run = self.get_current_published_run() - - if published_run: - versions = published_run.versions.all() - first_version = versions[0] - for version, older_version in pairwise(versions): - first_version = older_version - self._render_version_row(version, older_version) - self._render_version_row(first_version, None) + versions = self.current_pr.versions.all() + first_version = versions[0] + for version, older_version in pairwise(versions): + first_version = older_version + self._render_version_row(version, older_version) + self._render_version_row(first_version, None) def _render_version_row( self, @@ -957,7 +929,7 @@ def render_related_workflows(self): def _render(page_cls: typing.Type[BasePage]): page = page_cls() - root_run = page.get_root_published_run() + root_run = page.get_root_pr() state = root_run.saved_run.to_dict() preview_image = page.get_explore_image() @@ -1034,11 +1006,9 @@ def render_report_form(self): gui.error("Reason for report cannot be empty") return - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - send_reported_run_email( user=self.request.user, - run_uid=uid, + run_uid=str(self.run_user.uid), url=self.current_app_url(), recipe_name=self.title, report_type=report_type, @@ -1047,7 +1017,7 @@ def render_report_form(self): ) if report_type == inappropriate_radio_text: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=True) + self.update_flag_for_run(is_flagged=True) # gui.success("Reported.") gui.session_state["show_report_workflow"] = False @@ -1064,10 +1034,8 @@ def _check_if_flagged(self): if not unflag_pressed: return with gui.spinner("Removing flag..."): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - if run_id and uid: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=False) - gui.success("Removed flag.", icon="✅") + self.update_flag_for_run(is_flagged=False) + gui.success("Removed flag.") sleep(2) gui.rerun() else: @@ -1077,87 +1045,47 @@ def _check_if_flagged(self): # Return and Don't render the run any further gui.stop() - @classmethod - def get_runs_from_query_params( - cls, example_id: str, run_id: str, uid: str - ) -> tuple[SavedRun, PublishedRun | None]: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - pr = sr.parent_published_run() - else: - pr = cls.get_published_run(published_run_id=example_id or "") - sr = pr.saved_run - return sr, pr + def update_flag_for_run(self, is_flagged: bool): + sr = self.current_sr + sr.is_flagged = is_flagged + sr.save(update_fields=["is_flagged"]) + gui.session_state["is_flagged"] = is_flagged - @classmethod - def get_current_published_run(cls) -> PublishedRun | None: - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.get_pr_from_query_params(example_id, run_id, uid) + @property + def current_sr(self) -> SavedRun: + return self.current_sr_pr[0] - @classmethod - def get_pr_from_query_params( - cls, example_id: str, run_id: str, uid: str - ) -> PublishedRun | None: - if run_id and uid: - sr = cls.get_sr_from_query_params(example_id, run_id, uid) - return sr.parent_published_run() or cls.get_root_published_run() - elif example_id: - return cls.get_published_run(published_run_id=example_id) - else: - return cls.get_root_published_run() + @property + def current_pr(self) -> PublishedRun: + return self.current_sr_pr[1] - @classmethod - def get_published_run(cls, *, published_run_id: str): - return PublishedRun.objects.get( - workflow=cls.workflow, - published_run_id=published_run_id, + @cached_property + def current_sr_pr(self) -> tuple[SavedRun, PublishedRun]: + return self.get_sr_pr_from_query_params( + *extract_query_params(self.request.query_params) ) @classmethod - def get_current_sr(cls) -> SavedRun: - return cls.get_sr_from_query_params_dict(gui.get_query_params()) - - @classmethod - def get_sr_from_query_params_dict(cls, query_params) -> SavedRun: - example_id, run_id, uid = extract_query_params(query_params) - return cls.get_sr_from_query_params(example_id, run_id, uid) - - @classmethod - def get_sr_from_query_params( - cls, example_id: str | None, run_id: str | None, uid: str | None - ) -> SavedRun: - try: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - elif example_id: - pr = cls.get_published_run(published_run_id=example_id) - assert ( - pr.saved_run is not None - ), "invalid published run: without a saved run" - sr = pr.saved_run - else: - sr = cls.recipe_doc_sr() - return sr - except (SavedRun.DoesNotExist, PublishedRun.DoesNotExist): - raise HTTPException(status_code=404) - - @classmethod - def get_total_runs(cls) -> int: - # TODO: fix to also handle published run case - return SavedRun.objects.filter(workflow=cls.workflow).count() - - @classmethod - def recipe_doc_sr(cls, create: bool = True) -> SavedRun: - if create: - return cls.get_root_published_run().saved_run + def get_sr_pr_from_query_params( + cls, example_id: str, run_id: str, uid: str + ) -> tuple[SavedRun, PublishedRun]: + if run_id and uid: + sr = cls.get_sr_from_ids(run_id, uid) + pr = sr.parent_published_run() or cls.get_root_pr() else: - return cls.get_root_published_run().saved_run + if example_id: + pr = cls.get_pr_from_example_id(example_id=example_id) + else: + pr = cls.get_root_pr() + sr = pr.saved_run + return sr, pr @classmethod - def run_doc_sr( + def get_sr_from_ids( cls, run_id: str, uid: str, + *, create: bool = False, defaults: dict = None, ) -> SavedRun: @@ -1168,7 +1096,14 @@ def run_doc_sr( return SavedRun.objects.get(**config) @classmethod - def get_root_published_run(cls) -> PublishedRun: + def get_pr_from_example_id(cls, *, example_id: str): + return PublishedRun.objects.get( + workflow=cls.workflow, + published_run_id=example_id, + ) + + @classmethod + def get_root_pr(cls) -> PublishedRun: return PublishedRun.objects.get_or_create_with_version( workflow=cls.workflow, published_run_id="", @@ -1219,6 +1154,11 @@ def duplicate_published_run( visibility=visibility, ) + @classmethod + def get_total_runs(cls) -> int: + # TODO: fix to also handle published run case + return SavedRun.objects.filter(workflow=cls.workflow).count() + def render_description(self): pass @@ -1328,9 +1268,8 @@ def render_submit_button(self, key="--submit-1"): def render_run_cost(self): url = self.get_credits_click_url() - sr = self.get_current_sr() - if sr.price: - run_cost = sr.price + if self.current_sr.price: + run_cost = self.current_sr.price else: run_cost = self.get_price_roundoff(gui.session_state) ret = f'Run cost = {run_cost} credits' @@ -1356,7 +1295,7 @@ def _render_step_row(self): with col2: placeholder = gui.div() render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre + saved_run=self.current_sr, trigger=FunctionTrigger.pre ) try: self.render_steps() @@ -1366,7 +1305,7 @@ def _render_step_row(self): with placeholder: gui.write("##### 👣 Steps") render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.post + saved_run=self.current_sr, trigger=FunctionTrigger.post ) def _render_help(self): @@ -1447,9 +1386,10 @@ def run_v2( raise NotImplementedError def _render_report_button(self): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - # only logged in users can report a run (but not examples/default runs) - if not (self.request.user and run_id and uid): + sr, pr = self.current_sr_pr + is_example = pr.saved_run_id == sr.id + # only logged in users can report a run (but not examples/root runs) + if not self.request.user or is_example: return reported = gui.button( @@ -1461,12 +1401,6 @@ def _render_report_button(self): gui.session_state["show_report_workflow"] = reported gui.rerun() - def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): - ref = self.run_doc_sr(uid=uid, run_id=run_id) - ref.is_flagged = is_flagged - ref.save(update_fields=["is_flagged"]) - gui.session_state["is_flagged"] = is_flagged - # Functions in every recipe feels like overkill for now, hide it in settings functions_in_settings = True show_settings = True @@ -1617,8 +1551,7 @@ def on_submit(self): def should_submit_after_login(self) -> bool: return ( - gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) - and self.request + self.request.query_params.get(SUBMIT_AFTER_LOGIN_Q) and self.request.user and not self.request.user.is_anonymous ) @@ -1647,19 +1580,13 @@ def create_new_run( run_id = get_random_doc_id() - parent_example_id, parent_run_id, parent_uid = extract_query_params( - gui.get_query_params() - ) - parent = self.get_sr_from_query_params( - parent_example_id, parent_run_id, parent_uid - ) - published_run = self.get_current_published_run() + parent, pr = self.current_sr_pr try: - parent_version = published_run and published_run.versions.latest() + parent_version = pr.versions.latest() except PublishedRunVersion.DoesNotExist: parent_version = None - sr = self.run_doc_sr( + sr = self.get_sr_from_ids( run_id, uid, create=True, @@ -1697,7 +1624,7 @@ def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True): ) @classmethod - def realtime_channel_name(cls, run_id, uid): + def realtime_channel_name(cls, run_id: str, uid: str) -> str: return f"gooey-outputs/{cls.slug_versions[0]}/{uid}/{run_id}" def generate_credit_error_message(self, run_id, uid) -> str: @@ -1849,7 +1776,7 @@ def _history_tab(self): if self.is_current_user_admin(): uid = self.request.query_params.get("uid", uid) - before = gui.get_query_params().get("updated_at__lt", None) + before = self.request.query_params.get("updated_at__lt", None) if before: before = datetime.datetime.fromisoformat(before) else: @@ -2051,11 +1978,10 @@ def run_as_api_tab(self): as_async = gui.checkbox("##### Run Async") as_form_data = gui.checkbox("##### Upload Files via Form Data") - pr = self.get_current_published_run() api_url, request_body = self.get_example_request( gui.session_state, include_all=include_all, - pr=pr, + pr=self.current_pr, ) response_body = self.get_example_response_body( gui.session_state, as_async=as_async, include_all=include_all @@ -2105,9 +2031,7 @@ def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict): raise InsufficientCredits(self.request.user, sr) def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: - assert ( - self.request and self.request.user - ), "request.user must be set to deduct credits" + assert self.request.user, "request.user must be set to deduct credits" amount = self.get_price_roundoff(state) txn = self.request.user.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") @@ -2124,7 +2048,7 @@ def get_raw_price(self, state: dict) -> float: def get_total_linked_usage_cost_in_credits(self, default=1): """Return the sum of the linked usage costs in gooey credits.""" - sr = self.get_current_sr() + sr = self.current_sr total = sr.usage_costs.aggregate(total=Sum("dollar_amount"))["total"] if not total: return default @@ -2132,10 +2056,8 @@ def get_total_linked_usage_cost_in_credits(self, default=1): def get_grouped_linked_usage_cost_in_credits(self): """Return the linked usage costs grouped by model name in gooey credits.""" - qs = ( - self.get_current_sr() - .usage_costs.values("pricing__model_name") - .annotate(total=Sum("dollar_amount") * settings.ADDON_CREDITS_PER_DOLLAR) + qs = self.current_sr.usage_costs.values("pricing__model_name").annotate( + total=Sum("dollar_amount") * settings.ADDON_CREDITS_PER_DOLLAR ) return {item["pricing__model_name"]: item["total"] for item in qs} @@ -2179,7 +2101,7 @@ def get_example_response_body( run_id=run_id, uid=self.request.user and self.request.user.uid, ) - sr = self.get_current_sr() + sr = self.current_sr output = sr.api_output(extract_model_fields(self.ResponseModel, state)) if as_async: return dict( @@ -2210,17 +2132,13 @@ def is_user_admin(cls, user: AppUser) -> bool: return email and email in settings.ADMIN_EMAILS def is_current_user_admin(self) -> bool: - if not self.request or not self.request.user: - return False - return self.is_user_admin(self.request.user) + return self.request.user and self.is_user_admin(self.request.user) def is_current_user_paying(self) -> bool: - return bool(self.request and self.request.user and self.request.user.is_paying) + return bool(self.request.user and self.request.user.is_paying) def is_current_user_owner(self) -> bool: - return bool( - self.request and self.request.user and self.run_user == self.request.user - ) + return bool(self.request.user and self.run_user == self.request.user) def started_at_text(dt: datetime.datetime): diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 36cca09ec..8e767b351 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -7,6 +7,7 @@ from django.db import transaction from django.utils.text import slugify from furl import furl +from starlette.requests import Request from app_users.models import AppUser from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform @@ -54,7 +55,7 @@ def integrations_welcome_screen(title: str): gui.caption("Analyze your usage. Update your Saved Run to test changes.") -def general_integration_settings(bi: BotIntegration, current_user: AppUser): +def general_integration_settings(bi: BotIntegration, request: Request): if gui.session_state.get(f"_bi_reset_{bi.id}"): gui.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( BotIntegration._meta.get_field("streaming_enabled").default @@ -101,9 +102,10 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): "📊 View Results", str( furl( - VideoBotsPage.current_app_url( - RecipeTabs.integrations, + VideoBotsPage.app_url( + tab=RecipeTabs.integrations, path_params=dict(integration_id=bi.api_integration_id()), + query_params=dict(request.query_params), ) ) / "analysis/" @@ -119,7 +121,7 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): key=key, internal_state=d, del_key=del_key, - current_user=current_user, + current_user=request.user, ) if not ret: return diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 90fef04ee..eea020936 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -440,7 +440,7 @@ def _process_and_send_msg( # wait for the celery task to finish get_celery_result_db_safe(result) # get the final state from db - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr state = sr.to_dict() bot.recipe_run_state = page.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 9b46c9cdf..186075646 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -129,7 +129,7 @@ def doc_extract_selector(current_user: AppUser | None): gui.write("###### Create Synthetic Data") gui.caption( f""" - To improve answer quality, pick a [synthetic data maker workflow]({DocExtractPage.get_root_published_run().get_app_url()}) to scan & OCR any images in your documents or transcribe & translate any videos. It also can synthesize a helpful FAQ. Adds ~2 minutes of one-time processing per file. + To improve answer quality, pick a [synthetic data maker workflow]({DocExtractPage.get_root_pr().get_app_url()}) to scan & OCR any images in your documents or transcribe & translate any videos. It also can synthesize a helpful FAQ. Adds ~2 minutes of one-time processing per file. """ ) workflow_url_input( diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index c3335434b..d2c255a62 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -17,11 +17,8 @@ def build_meta_tags( url: str, page: "BasePage", state: dict, - run_id: str, - uid: str, - example_id: str, ) -> list[dict]: - sr, pr = page.get_runs_from_query_params(example_id, run_id, uid) + sr, pr = page.current_sr_pr metadata = page.workflow.get_or_create_metadata() title = meta_title_for_page( diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 84c14baf5..8a8d67336 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -29,7 +29,7 @@ def safety_checker_text(text_input: str): # run in a thread to avoid messing up threadlocals result, sr = ( CompareLLMPage() - .get_published_run(published_run_id=settings.SAFETY_CHECKER_EXAMPLE_ID) + .get_pr_from_example_id(example_id=settings.SAFETY_CHECKER_EXAMPLE_ID) .submit_api_call( current_user=billing_account, request_body=dict(variables=dict(input=text_input)), diff --git a/daras_ai_v2/workflow_url_input.py b/daras_ai_v2/workflow_url_input.py index 0b82cd2a8..b9f23cc56 100644 --- a/daras_ai_v2/workflow_url_input.py +++ b/daras_ai_v2/workflow_url_input.py @@ -136,7 +136,7 @@ def url_to_runs( assert match, "Not a valid Gooey.AI URL" page_cls = page_slug_map[normalize_slug(match.matched_params["page_slug"])] example_id, run_id, uid = extract_query_params(furl(url).query.params) - sr, pr = page_cls.get_runs_from_query_params( + sr, pr = page_cls.get_sr_pr_from_query_params( example_id or match.matched_params.get("example_id"), run_id, uid ) return page_cls, sr, pr @@ -177,7 +177,7 @@ def get_published_run_options( if include_root: # include root recipe if requested options_dict = { - page_cls.get_root_published_run().get_app_url(): "Default", + page_cls.get_root_pr().get_app_url(): "Default", } | options_dict return options_dict diff --git a/explore.py b/explore.py index cd4ec3d00..5bbe380fc 100644 --- a/explore.py +++ b/explore.py @@ -85,7 +85,7 @@ def render_description(page: BasePage): with gui.link(to=page.app_url()): gui.markdown(f"#### {page.get_recipe_title()}") - root_pr = page.get_root_published_run() + root_pr = page.get_root_pr() notes = root_pr.notes or page.preview_description(state=page.sane_defaults) with gui.tag("p", style={"marginBottom": "25px"}): gui.write(notes, line_clamp=4) diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 04e034dde..d0c49f3c7 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -136,7 +136,7 @@ def render_settings(self): gui.write("---") gui.write("##### 🔎 Document Search Settings") citation_style_selector() - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) query_instructions_widget() gui.write("---") doc_search_advanced_settings() @@ -175,7 +175,7 @@ def run_v2( "search_query": response.final_search_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) # empty search result, abort! diff --git a/recipes/Functions.py b/recipes/Functions.py index 356381343..81b99c946 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -59,10 +59,8 @@ def run_v2( request: "FunctionsPage.RequestModel", response: "FunctionsPage.ResponseModel", ) -> typing.Iterator[str | None]: - query_params = gui.get_query_params() - run_id = query_params.get("run_id") - uid = query_params.get("uid") - tag = f"run_id={run_id}&uid={uid}" + sr = self.current_sr + tag = f"run_id={sr.run_id}&uid={sr.uid}" yield "Running your code..." # this will run functions/executor.js in deno deploy @@ -86,7 +84,7 @@ def render_form_v2(self): ) def get_price_roundoff(self, state: dict) -> float: - if CalledFunction.objects.filter(function_run=self.get_current_sr()).exists(): + if CalledFunction.objects.filter(function_run=self.current_sr).exists(): return 0 return super().get_price_roundoff(state) diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 6a9eaf999..611483287 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -254,7 +254,7 @@ def run_v2( }, ), is_user_url=False, - current_user=self.request and self.request.user, + current_user=self.request.user, ) # add pretty titles to references for ref in response.references: diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 41e7a33a8..26b7e785e 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -85,7 +85,6 @@ from daras_ai_v2.prompt_vars import render_prompt_vars from daras_ai_v2.pydantic_validation import FieldHttpUrl from daras_ai_v2.query_generator import generate_final_search_query -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.search_ref import ( parse_refs, CitationStyles, @@ -521,7 +520,7 @@ def render_settings(self): citation_style_selector() gui.checkbox("🔗 Shorten Citation URLs", key="use_url_shortener") - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) gui.write("---") @@ -886,7 +885,7 @@ def run_v2( "keyword_query": response.final_keyword_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) if request.use_url_shortener: for reference in response.references: @@ -1061,9 +1060,7 @@ def render_integrations_tab(self): gui.anchor("Get Started", href=self.get_auth_url(), type="primary") return - sr, pr = self.get_runs_from_query_params( - *extract_query_params(gui.get_query_params()) - ) + sr, pr = self.current_sr_pr # make user the user knows that they are on a saved run not the published run if pr and pr.saved_run_id != sr.id: @@ -1377,7 +1374,7 @@ def render_integrations_settings( slack_specific_settings(bi, run_title) if bi.platform == Platform.TWILIO: twilio_specific_settings(bi) - general_integration_settings(bi, self.request.user) + general_integration_settings(bi, self.request) if bi.platform in [Platform.SLACK, Platform.WHATSAPP, Platform.TWILIO]: gui.newline() diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 648a2894f..3f3bbacb6 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -84,8 +84,9 @@ def show_title_breadcrumb_share( ) gui.breadcrumb_item( "Integrations", - link_to=VideoBotsPage.current_app_url( - RecipeTabs.integrations, + link_to=VideoBotsPage.app_url( + tab=RecipeTabs.integrations, + query_params=dict(self.request.query_params), path_params=dict( integration_id=bi.api_integration_id() ), @@ -152,7 +153,7 @@ def render(self): ) ) - run_url = VideoBotsPage.current_app_url() + run_url = VideoBotsPage.app_url(query_params=dict(self.request.query_params)) if bi.published_run_id: run_title = bi.published_run.title else: diff --git a/routers/api.py b/routers/api.py index 9b795d426..dd74b5a00 100644 --- a/routers/api.py +++ b/routers/api.py @@ -258,8 +258,13 @@ def get_run_status( run_id: str, user: AppUser = Depends(api_auth_header), ): - self = page_cls() - sr = self.get_sr_from_query_params(example_id=None, run_id=run_id, uid=user.uid) + # init a new page for every request + self = page_cls( + request=SimpleNamespace( + user=user, query_params=dict(run_id=run_id, uid=user.uid) + ) + ) + sr = self.current_sr web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid))) ret = { "run_id": run_id, @@ -335,18 +340,17 @@ def submit_api_call( deduct_credits: bool = True, ) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: # init a new page for every request - self = page_cls(request=SimpleNamespace(user=user)) + query_params.setdefault("uid", user.uid) + self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) # get saved state from db - query_params.setdefault("uid", user.uid) - sr = self.get_sr_from_query_params_dict(query_params) + sr = self.current_sr state = self.load_state_from_sr(sr) # load request data state.update(request_body) # set streamlit session state gui.set_session_state(state) - gui.set_query_params(query_params) # create a new run try: @@ -369,7 +373,7 @@ def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: / page.endpoint.replace("v2", "v3") / "status/" ) - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr return dict( run_id=run_id, web_url=web_url, @@ -388,7 +392,8 @@ def build_sync_api_response( web_url = page.app_url(run_id=run_id, uid=uid) # wait for the result get_celery_result_db_safe(result) - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr + sr.refresh_from_db() if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) diff --git a/routers/bots_api.py b/routers/bots_api.py index 780a7b918..64285cbd2 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -302,7 +302,7 @@ def runner(self): msg_handler(self) # raise ValueError("Stream ended") if self.run_id and self.uid: - sr = self.page_cls.run_doc_sr(run_id=self.run_id, uid=self.uid) + sr = self.page_cls.get_sr_from_ids(run_id=self.run_id, uid=self.uid) state = sr.to_dict() self.queue.put( FinalResponse( diff --git a/routers/root.py b/routers/root.py index e234a5443..b2375439d 100644 --- a/routers/root.py +++ b/routers/root.py @@ -39,7 +39,6 @@ from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.profiles import user_profile_page, get_meta_tags_for_profile -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.settings import templates from handles.models import Handle from routers.custom_api_router import CustomAPIRouter @@ -314,7 +313,7 @@ def _api_docs_page(request): as_form_data = gui.checkbox("Upload Files via Form Data") page = workflow.page_cls(request=request) - state = page.get_root_published_run().saved_run.to_dict() + state = page.get_root_pr().saved_run.to_dict() api_url, request_body = page.get_example_request(state, include_all=include_all) response_body = page.get_example_response_body( state, as_async=as_async, include_all=include_all @@ -669,12 +668,13 @@ def render_recipe_page( return RedirectResponse(str(new_url.set(origin=None)), status_code=301) # this is because the code still expects example_id to be in the query params - gui.set_query_params(dict(request.query_params) | dict(example_id=example_id)) - _, run_id, uid = extract_query_params(request.query_params) + request._query_params = dict(request.query_params) | dict(example_id=example_id) + + page = page_cls(tab=tab, request=request) + sr = page.current_sr + page.run_user = get_run_user(request, sr.uid) - page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid)) if not gui.session_state: - sr = page.get_sr_from_query_params(example_id, run_id, uid) gui.session_state.update(page.load_state_from_sr(sr)) with page_wrapper(request): @@ -682,12 +682,7 @@ def render_recipe_page( return dict( meta=build_meta_tags( - url=get_og_url_path(request), - page=page, - state=gui.session_state, - run_id=run_id, - uid=uid, - example_id=example_id, + url=get_og_url_path(request), page=page, state=gui.session_state ), ) @@ -698,7 +693,7 @@ def get_og_url_path(request) -> str: ) -def get_run_user(request, uid) -> AppUser | None: +def get_run_user(request: Request, uid: str) -> AppUser | None: if not uid: return if request.user and request.user.uid == uid: diff --git a/routers/twilio_api.py b/routers/twilio_api.py index 7893c15d1..0223636a9 100644 --- a/routers/twilio_api.py +++ b/routers/twilio_api.py @@ -262,7 +262,7 @@ def resp_say_or_tts_play( tts_state = TextToSpeechPage.RequestModel.parse_obj( {**bot.saved_run.state, "text_prompt": text} ).dict() - result, sr = TextToSpeechPage.get_root_published_run().submit_api_call( + result, sr = TextToSpeechPage.get_root_pr().submit_api_call( current_user=AppUser.objects.get(uid=bot.billing_account_uid), request_body=tts_state, ) diff --git a/tests/test_apis.py b/tests/test_apis.py index fa897eb83..7220798e3 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -21,7 +21,7 @@ def test_apis_sync(mock_celery_tasks, force_authentication, threadpool_subtest): def _test_api_sync(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_pr().saved_run.state r = client.post( f"/v2/{page_cls.slug_versions[0]}/", json=page_cls.get_example_request(state)[1], @@ -38,7 +38,7 @@ def test_apis_async(mock_celery_tasks, force_authentication, threadpool_subtest) def _test_api_async(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_pr().saved_run.state r = client.post( f"/v3/{page_cls.slug_versions[0]}/async/", diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 6ac6e0591..73e05e4a9 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import gooey_gui as gui import pytest from starlette.testclient import TestClient @@ -46,8 +48,12 @@ def test_copilot_get_raw_price_round_up(): unit_quantity=model_pricing.unit_quantity, dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) - copilot_page = VideoBotsPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + copilot_page = VideoBotsPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ), + ) assert ( copilot_page.get_price_roundoff(state=state) == 210 + copilot_page.PROFIT_CREDITS @@ -107,8 +113,12 @@ def test_multiple_llm_sums_usage_cost(): dollar_amount=model_pricing2.unit_cost * 1 / model_pricing2.unit_quantity, ) - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -152,8 +162,12 @@ def test_workflowmetadata_2x_multiplier(): metadata.price_multiplier = 2 metadata.save() - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2 ) diff --git a/url_shortener/models.py b/url_shortener/models.py index 21b0864d7..515b293a2 100644 --- a/url_shortener/models.py +++ b/url_shortener/models.py @@ -6,10 +6,9 @@ from app_users.models import AppUser from bots.custom_fields import CustomURLField from bots.models import Workflow, SavedRun +from celeryapp.tasks import get_running_saved_run from daras_ai.image_input import truncate_filename from daras_ai_v2 import settings -from daras_ai_v2.query_params_util import extract_query_params -import gooey_gui as gui class ShortenedURLQuerySet(models.QuerySet): @@ -17,14 +16,8 @@ def get_or_create_for_workflow( self, *, user: AppUser, workflow: Workflow, **kwargs ) -> tuple["ShortenedURL", bool]: surl, created = self.filter_first_or_create(user=user, **kwargs) - _, run_id, uid = extract_query_params(gui.get_query_params()) - surl.saved_runs.add( - SavedRun.objects.get_or_create( - workflow=workflow, - run_id=run_id, - uid=uid, - )[0], - ) + sr = get_running_saved_run() + surl.saved_runs.add(sr) return surl, created def filter_first_or_create(self, defaults=None, **kwargs): diff --git a/usage_costs/cost_utils.py b/usage_costs/cost_utils.py index 6596c0d22..b380ee63f 100644 --- a/usage_costs/cost_utils.py +++ b/usage_costs/cost_utils.py @@ -1,19 +1,16 @@ from loguru import logger -from daras_ai_v2.query_params_util import extract_query_params +from celeryapp.tasks import get_running_saved_run from usage_costs.models import ( UsageCost, ModelSku, ModelPricing, ) -import gooey_gui as gui def record_cost_auto(model: str, sku: ModelSku, quantity: int): - from bots.models import SavedRun - - _, run_id, uid = extract_query_params(gui.get_query_params()) - if not run_id or not uid: + sr = get_running_saved_run() + if not sr: return try: @@ -22,10 +19,8 @@ def record_cost_auto(model: str, sku: ModelSku, quantity: int): logger.warning(f"Cant find pricing for {model=} {sku=}: {e=}") return - saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) - UsageCost.objects.create( - saved_run=saved_run, + saved_run=sr, pricing=pricing, quantity=quantity, unit_cost=pricing.unit_cost, From 7f4ed8340983c663c3ccd133b41794268c25cea9 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 30 Aug 2024 20:17:19 +0530 Subject: [PATCH 2/4] Refactor load_state_from_sr method to current_sr_to_session_state across the codebase --- daras_ai_v2/base.py | 7 +++---- recipes/asr_page.py | 4 ++-- routers/api.py | 3 +-- routers/root.py | 5 ++--- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index d1f3715c1..e2f05a66f 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1683,12 +1683,11 @@ def _render_after_output(self): gui.session_state[StateKeys.pressed_randomize] = True gui.rerun() - @classmethod - def load_state_from_sr(cls, sr: SavedRun) -> dict: - state = sr.to_dict() + def current_sr_to_session_state(self) -> dict: + state = self.current_sr.to_dict() if state is None: raise HTTPException(status_code=404) - return cls.load_state_defaults(state) + return self.load_state_defaults(state) @classmethod def load_state_defaults(cls, state: dict): diff --git a/recipes/asr_page.py b/recipes/asr_page.py index 58d49ffa4..f68be1a08 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -64,8 +64,8 @@ class ResponseModel(BaseModel): raw_output_text: list[str] | None output_text: list[str | AsrOutputJson] - def load_state_from_sr(self, sr: SavedRun) -> dict: - state = super().load_state_from_sr(sr) + def current_sr_to_session_state(self) -> dict: + state = super().current_sr_to_session_state() google_translate_target = state.pop("google_translate_target", None) translation_model = state.get("translation_model") if google_translate_target and not translation_model: diff --git a/routers/api.py b/routers/api.py index dd74b5a00..d1878bd53 100644 --- a/routers/api.py +++ b/routers/api.py @@ -344,8 +344,7 @@ def submit_api_call( self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) # get saved state from db - sr = self.current_sr - state = self.load_state_from_sr(sr) + state = self.current_sr_to_session_state() # load request data state.update(request_body) diff --git a/routers/root.py b/routers/root.py index b2375439d..9f678416b 100644 --- a/routers/root.py +++ b/routers/root.py @@ -671,11 +671,10 @@ def render_recipe_page( request._query_params = dict(request.query_params) | dict(example_id=example_id) page = page_cls(tab=tab, request=request) - sr = page.current_sr - page.run_user = get_run_user(request, sr.uid) + page.run_user = get_run_user(request, page.current_sr.uid) if not gui.session_state: - gui.session_state.update(page.load_state_from_sr(sr)) + gui.session_state.update(page.current_sr_to_session_state()) with page_wrapper(request): page.render() From 68cfbe89911ce2dca2a22bb8b176c95401c25ca4 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 6 Sep 2024 14:00:12 +0530 Subject: [PATCH 3/4] Refactor consistent usage of submit_api_call() - Introduce `SavedRun.wait_for_celery_result` to encapsulate common logic. - Change BasePage `endpoint` to method `api_endpoint`. - Update `submit_api_call`, `build_sync_api_response`, and `build_async_api_response` signatures. --- bots/models.py | 13 ++++--- bots/tasks.py | 4 +- daras_ai_v2/base.py | 7 ++-- daras_ai_v2/bots.py | 22 +++++------ daras_ai_v2/safety_checker.py | 3 +- functions/recipe_functions.py | 3 +- recipes/BulkRunner.py | 12 +++--- routers/api.py | 72 +++++++++++++++++------------------ routers/bots_api.py | 14 +++---- routers/twilio_api.py | 3 +- 10 files changed, 72 insertions(+), 81 deletions(-) diff --git a/bots/models.py b/bots/models.py index 3ace6bc33..6c3c1d92b 100644 --- a/bots/models.py +++ b/bots/models.py @@ -18,6 +18,7 @@ from daras_ai_v2.crypto import get_random_doc_id from daras_ai_v2.language_model import format_chat_entry from functions.models import CalledFunction, CalledFunctionResponse +from gooeysite.bg_db_conn import get_celery_result_db_safe from gooeysite.custom_create import get_or_create_lazy if typing.TYPE_CHECKING: @@ -358,8 +359,8 @@ def submit_api_call( current_user: AppUser, request_body: dict, enable_rate_limits: bool = False, - parent_pr: "PublishedRun" = None, deduct_credits: bool = True, + parent_pr: "PublishedRun" = None, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: from routers.api import submit_api_call @@ -373,19 +374,21 @@ def submit_api_call( query_params = page_cls.clean_query_params( example_id=self.example_id, run_id=self.run_id, uid=self.uid ) - page, result, run_id, uid = pool.apply( + return pool.apply( submit_api_call, kwds=dict( page_cls=page_cls, query_params=query_params, - user=current_user, + current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, deduct_credits=deduct_credits, ), ) - return result, page.current_sr + def wait_for_celery_result(self, result: "celery.result.AsyncResult"): + get_celery_result_db_safe(result) + self.refresh_from_db() def get_creator(self) -> AppUser | None: if self.uid: @@ -1839,8 +1842,8 @@ def submit_api_call( current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, - parent_pr=self, deduct_credits=deduct_credits, + parent_pr=self, ) diff --git a/bots/tasks.py b/bots/tasks.py index 70f381bf1..0a76e42bc 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -97,9 +97,7 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None): # save the run before the result is ready Message.objects.filter(id=msg_id).update(analysis_run=sr) - # wait for the result - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index e2f05a66f..90432553c 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -164,8 +164,7 @@ def __init__( self.run_user = run_user @classmethod - @property - def endpoint(cls) -> str: + def api_endpoint(cls) -> str: return f"/v2/{cls.slug_versions[0]}" def current_app_url( @@ -241,7 +240,9 @@ def api_url( query_params = dict(run_id=run_id, uid=uid) elif example_id: query_params = dict(example_id=example_id) - return furl(settings.API_BASE_URL, query_params=query_params) / cls.endpoint + return ( + furl(settings.API_BASE_URL, query_params=query_params) / cls.api_endpoint() + ) @classmethod def clean_query_params(cls, *, example_id, run_id, uid) -> dict: diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index eea020936..e8fb61c6b 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -1,6 +1,7 @@ import mimetypes import typing from datetime import datetime +from types import SimpleNamespace import gooey_gui as gui from django.db import transaction @@ -199,9 +200,7 @@ def get_input_documents(self) -> list[str] | None: def get_interactive_msg_info(self) -> ButtonPressed: raise NotImplementedError("This bot does not support interactive messages.") - def on_run_created( - self, page: BasePage, result: "celery.result.AsyncResult", run_id: str, uid: str - ): + def on_run_created(self, sr: "SavedRun"): pass def send_run_status(self, update_msg_id: str | None) -> str | None: @@ -376,13 +375,13 @@ def _process_and_send_msg( variables.update(bot.request_overrides["variables"]) except KeyError: pass - page, result, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=bot.page_cls, - user=billing_account_user, - request_body=body, query_params=bot.query_params, + current_user=billing_account_user, + request_body=body, ) - bot.on_run_created(page, result, run_id, uid) + bot.on_run_created(sr) if bot.show_feedback_buttons: buttons = _feedback_start_buttons() @@ -394,10 +393,10 @@ def _process_and_send_msg( last_idx = 0 # this is the last index of the text sent to the user if bot.streaming_enabled: # subscribe to the realtime channel for updates - channel = page.realtime_channel_name(run_id, uid) + channel = bot.page_cls.realtime_channel_name(sr.run_id, sr.uid) with gui.realtime_subscribe(channel) as realtime_gen: for state in realtime_gen: - bot.recipe_run_state = page.get_run_state(state) + bot.recipe_run_state = bot.page_cls.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" # check for errors if bot.recipe_run_state == RecipeRunState.failed: @@ -438,11 +437,10 @@ def _process_and_send_msg( break # we're done streaming, stop the loop # wait for the celery task to finish - get_celery_result_db_safe(result) + sr.wait_for_celery_result(result) # get the final state from db - sr = page.current_sr state = sr.to_dict() - bot.recipe_run_state = page.get_run_state(state) + bot.recipe_run_state = bot.page_cls.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" # check for errors err_msg = state.get(StateKeys.error_msg) diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 8a8d67336..7faf6b66b 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -38,8 +38,7 @@ def safety_checker_text(text_input: str): ) # wait for checker - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if checker failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py index b7fd36fdb..21d3fa185 100644 --- a/functions/recipe_functions.py +++ b/functions/recipe_functions.py @@ -63,8 +63,7 @@ def call_recipe_functions( # wait for the result if its a pre request function if trigger == FunctionTrigger.post: continue - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index f700effac..7ad9e67e9 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -2,10 +2,10 @@ import typing import uuid +import gooey_gui as gui from furl import furl from pydantic import BaseModel, Field -import gooey_gui as gui from bots.models import Workflow, SavedRun from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import icons @@ -322,8 +322,7 @@ def run_v2( request_body=request_body, parent_pr=pr, ) - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) run_time = datetime.timedelta( seconds=int(sr.run_time.total_seconds()) @@ -390,10 +389,11 @@ def run_v2( documents=response.output_documents ).dict(exclude_unset=True) result, sr = sr.submit_api_call( - current_user=self.request.user, request_body=request_body, parent_pr=pr + current_user=self.request.user, + request_body=request_body, + parent_pr=pr, ) - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) response.eval_runs.append(sr.get_app_url()) def preview_description(self, state: dict) -> str: diff --git a/routers/api.py b/routers/api.py index d1878bd53..68ad4c2b6 100644 --- a/routers/api.py +++ b/routers/api.py @@ -31,7 +31,7 @@ from app_users.models import AppUser from auth.token_authentication import api_auth_header -from bots.models import RetentionPolicy +from bots.models import RetentionPolicy, Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings from daras_ai_v2.all_pages import all_api_pages @@ -41,9 +41,12 @@ ) from daras_ai_v2.fastapi_tricks import fastapi_request_form from functions.models import CalledFunctionResponse -from gooeysite.bg_db_conn import get_celery_result_db_safe from routers.custom_api_router import CustomAPIRouter +if typing.TYPE_CHECKING: + from bots.models import SavedRun + import celery.result + app = CustomAPIRouter() @@ -117,7 +120,7 @@ class RunSettings(BaseModel): def script_to_api(page_cls: typing.Type[BasePage]): - endpoint = page_cls().endpoint.rstrip("/") + endpoint = page_cls.api_endpoint().rstrip("/") # add the common settings to the request model request_model = create_model( page_cls.__name__ + "Request", @@ -156,15 +159,15 @@ def run_api_json( page_request: request_model, user: AppUser = Depends(api_auth_header), ): - page, result, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=page_cls, - user=user, - request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), retention_policy=RetentionPolicy[page_request.settings.retention_policy], + current_user=user, + request_body=page_request.dict(exclude_unset=True), enable_rate_limits=True, ) - return build_sync_api_response(page=page, result=result, run_id=run_id, uid=uid) + return build_sync_api_response(result, sr) @app.post( os.path.join(endpoint, "form"), @@ -205,15 +208,15 @@ def run_api_json_async( page_request: request_model, user: AppUser = Depends(api_auth_header), ): - page, _, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=page_cls, - user=user, - request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), retention_policy=RetentionPolicy[page_request.settings.retention_policy], + current_user=user, + request_body=page_request.dict(exclude_unset=True), enable_rate_limits=True, ) - ret = build_async_api_response(page=page, run_id=run_id, uid=uid) + ret = build_async_api_response(sr) response.headers["Location"] = ret["status_url"] response.headers["Access-Control-Expose-Headers"] = "Location" return ret @@ -332,19 +335,21 @@ def _parse_form_data( def submit_api_call( *, page_cls: typing.Type[BasePage], - request_body: dict, - user: AppUser, query_params: dict, retention_policy: RetentionPolicy = None, + current_user: AppUser, + request_body: dict, enable_rate_limits: bool = False, deduct_credits: bool = True, -) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: +) -> tuple["celery.result.AsyncResult", "SavedRun"]: # init a new page for every request - query_params.setdefault("uid", user.uid) - self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) + query_params.setdefault("uid", current_user.uid) + page = page_cls( + request=SimpleNamespace(user=current_user, query_params=query_params) + ) # get saved state from db - state = self.current_sr_to_session_state() + state = page.current_sr_to_session_state() # load request data state.update(request_body) @@ -353,7 +358,7 @@ def submit_api_call( # create a new run try: - sr = self.create_new_run( + sr = page.create_new_run( enable_rate_limits=enable_rate_limits, is_api_call=True, retention_policy=retention_policy or RetentionPolicy.keep, @@ -361,20 +366,19 @@ def submit_api_call( except ValidationError as e: raise RequestValidationError(e.raw_errors, body=gui.session_state) from e # submit the task - result = self.call_runner_task(sr, deduct_credits=deduct_credits) - return self, result, sr.run_id, sr.uid + result = page.call_runner_task(sr, deduct_credits=deduct_credits) + return result, sr -def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: - web_url = page.app_url(run_id=run_id, uid=uid) +def build_async_api_response(sr: "SavedRun") -> dict: + web_url = sr.get_app_url() status_url = str( - furl(settings.API_BASE_URL, query_params=dict(run_id=run_id)) - / page.endpoint.replace("v2", "v3") + furl(settings.API_BASE_URL, query_params=dict(run_id=sr.run_id)) + / Workflow(sr.workflow).page_cls.api_endpoint().replace("v2", "v3") / "status/" ) - sr = page.current_sr return dict( - run_id=run_id, + run_id=sr.run_id, web_url=web_url, created_at=sr.created_at.isoformat(), status_url=status_url, @@ -382,17 +386,11 @@ def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: def build_sync_api_response( - *, - page: BasePage, - result: "celery.result.AsyncResult", - run_id: str, - uid: str, + result: "celery.result.AsyncResult", sr: "SavedRun" ) -> JSONResponse: - web_url = page.app_url(run_id=run_id, uid=uid) + web_url = sr.get_app_url() # wait for the result - get_celery_result_db_safe(result) - sr = page.current_sr - sr.refresh_from_db() + sr.wait_for_celery_result(result) if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) @@ -401,7 +399,7 @@ def build_sync_api_response( return JSONResponse( dict( detail=dict( - id=run_id, + id=sr.run_id, url=web_url, created_at=sr.created_at.isoformat(), error=sr.error_msg, @@ -414,7 +412,7 @@ def build_sync_api_response( return JSONResponse( jsonable_encoder( dict( - id=run_id, + id=sr.run_id, url=web_url, created_at=sr.created_at.isoformat(), output=sr.api_output(), diff --git a/routers/bots_api.py b/routers/bots_api.py index 64285cbd2..021caae4b 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from starlette.responses import StreamingResponse, Response -from bots.models import Platform, Conversation, BotIntegration, Message +from bots.models import Platform, Conversation, BotIntegration, Message, SavedRun from celeryapp.tasks import err_msg_for_exc from daras_ai_v2 import settings from daras_ai_v2.base import RecipeRunState, BasePage, StateKeys @@ -320,14 +320,10 @@ def runner(self): finally: self.queue.put(None) - def on_run_created( - self, page: BasePage, result: "celery.result.AsyncResult", run_id: str, uid: str - ): - self.run_id = run_id - self.uid = uid - self.queue.put( - RunStart(**build_async_api_response(page=page, run_id=run_id, uid=uid)) - ) + def on_run_created(self, sr: SavedRun): + self.run_id = sr.run_id + self.uid = sr.uid + self.queue.put(RunStart(**build_async_api_response(sr))) def send_run_status(self, update_msg_id: str | None) -> str | None: self.queue.put( diff --git a/routers/twilio_api.py b/routers/twilio_api.py index 0223636a9..108f3e47b 100644 --- a/routers/twilio_api.py +++ b/routers/twilio_api.py @@ -267,8 +267,7 @@ def resp_say_or_tts_play( request_body=tts_state, ) # wait for the TTS to finish - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # check for errors if sr.error_msg: raise RuntimeError(sr.error_msg) From c80fb7c0032b1bcd1ecb9cd7dab6c7dc653fe398 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 6 Sep 2024 15:19:15 +0530 Subject: [PATCH 4/4] Remove usage of `SimpleNamespace` for request handling in BasePage Move BasePage.run_user -> cached property current_sr_user --- bots/admin.py | 7 ++---- celeryapp/tasks.py | 6 +---- daras_ai_v2/base.py | 54 +++++++++++++++++++++++++++++-------------- recipes/LipsyncTTS.py | 6 ++--- recipes/VideoBots.py | 8 ++----- routers/api.py | 11 ++------- routers/root.py | 29 ++++++++--------------- tests/test_pricing.py | 21 +++++------------ 8 files changed, 62 insertions(+), 80 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index b7e696dfd..8b52e233f 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -1,6 +1,5 @@ import datetime import json -from types import SimpleNamespace import django.db.models from django import forms @@ -439,10 +438,8 @@ def rerun_tasks(self, request, queryset): sr: SavedRun for sr in queryset.all(): page = Workflow(sr.workflow).page_cls( - request=SimpleNamespace( - user=AppUser.objects.get(uid=sr.uid), - query_params=dict(run_id=sr.run_id, uid=sr.uid), - ) + user=AppUser.objects.get(uid=sr.uid), + query_params=dict(run_id=sr.run_id, uid=sr.uid), ) page.call_runner_task(sr, deduct_credits=False) self.message_user( diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index b2e257365..3fed1442f 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -4,7 +4,6 @@ import traceback import typing from time import time -from types import SimpleNamespace import gooey_gui as gui import requests @@ -92,10 +91,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False page.dump_state_to_sr(gui.session_state | output, sr) page = page_cls( - request=SimpleNamespace( - user=AppUser.objects.get(id=user_id), - query_params=dict(run_id=run_id, uid=uid), - ), + user=AppUser.objects.get(id=user_id), query_params=dict(run_id=run_id, uid=uid) ) page.setup_sentry() sr = page.current_sr diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 90432553c..870debdb1 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -12,7 +12,6 @@ from itertools import pairwise from random import Random from time import sleep -from types import SimpleNamespace import gooey_gui as gui import sentry_sdk @@ -26,7 +25,6 @@ from sentry_sdk.tracing import ( TRANSACTION_SOURCE_ROUTE, ) -from starlette.requests import Request from app_users.models import AppUser, AppUserTransaction from bots.models import ( @@ -94,7 +92,6 @@ MAX_SEED = 4294967294 gooey_rng = Random() - SUBMIT_AFTER_LOGIN_Q = "submitafterlogin" @@ -117,6 +114,12 @@ class StateKeys: hidden = "__hidden" +class BasePageRequest: + user: AppUser | None + session: dict + query_params: dict + + class BasePage: title: str workflow: Workflow @@ -154,14 +157,20 @@ def __init__( self, *, tab: RecipeTabs = RecipeTabs.run, - request: Request | SimpleNamespace | None = None, - run_user: AppUser | None = None, + request: BasePageRequest | None = None, + user: AppUser | None = None, + request_session: dict | None = None, + query_params: dict | None = None, ): - if request is None: - request = SimpleNamespace(user=None, query_params={}) self.tab = tab + + if not request: + request = BasePageRequest() + request.user = user + request.session = request_session or {} + request.query_params = query_params or {} + self.request = request - self.run_user = run_user @classmethod def api_endpoint(cls) -> str: @@ -349,7 +358,7 @@ def _render_header(self): with gui.div(className="d-flex justify-content-between mt-4"): with gui.div(className="d-lg-flex d-block align-items-center"): - if not tbreadcrumbs.has_breadcrumbs() and not self.run_user: + if not tbreadcrumbs.has_breadcrumbs() and not self.current_sr_user: self._render_title(tbreadcrumbs.h1_title) if tbreadcrumbs: @@ -362,7 +371,7 @@ def _render_header(self): if is_example: author = pr.created_by else: - author = self.run_user or sr.get_creator() + author = self.current_sr_user or sr.get_creator() if not is_root_example: self.render_author(author) @@ -386,7 +395,7 @@ def _render_header(self): self._render_published_run_save_buttons(sr=sr, pr=pr) self._render_social_buttons(show_button_text=not show_save_buttons) - if tbreadcrumbs.has_breadcrumbs() or self.run_user: + if tbreadcrumbs.has_breadcrumbs() or self.current_sr_user: # only render title here if the above row was not empty self._render_title(tbreadcrumbs.h1_title) @@ -810,7 +819,7 @@ def get_explore_image(self) -> str: return meta_preview_url(img, fallback_img) def _user_disabled_check(self): - if self.run_user and self.run_user.is_disabled: + if self.current_sr_user and self.current_sr_user.is_disabled: msg = ( "This Gooey.AI account has been disabled for violating our [Terms of Service](/terms). " "Contact us at support@gooey.ai if you think this is a mistake." @@ -1009,7 +1018,7 @@ def render_report_form(self): send_reported_run_email( user=self.request.user, - run_uid=str(self.run_user.uid), + run_uid=str(self.current_sr_user.uid), url=self.current_app_url(), recipe_name=self.title, report_type=report_type, @@ -1052,11 +1061,22 @@ def update_flag_for_run(self, is_flagged: bool): sr.save(update_fields=["is_flagged"]) gui.session_state["is_flagged"] = is_flagged - @property + @cached_property + def current_sr_user(self) -> AppUser | None: + if not self.current_sr.uid: + return None + if self.request.user and self.request.user.uid == self.current_sr.uid: + return self.request.user + try: + return AppUser.objects.get(uid=self.current_sr.uid) + except AppUser.DoesNotExist: + return None + + @cached_property def current_sr(self) -> SavedRun: return self.current_sr_pr[0] - @property + @cached_property def current_pr(self) -> PublishedRun: return self.current_sr_pr[1] @@ -1571,7 +1591,7 @@ def create_new_run( uid = self.request.user.uid else: uid = auth.create_user().uid - self.request.scope["user"] = AppUser.objects.create( + self.request.user = AppUser.objects.create( uid=uid, is_anonymous=True, balance=settings.ANON_USER_FREE_CREDITS ) self.request.session[ANONYMOUS_USER_COOKIE] = dict(uid=uid) @@ -2138,7 +2158,7 @@ def is_current_user_paying(self) -> bool: return bool(self.request.user and self.request.user.is_paying) def is_current_user_owner(self) -> bool: - return bool(self.request.user and self.run_user == self.request.user) + return bool(self.request.user and self.current_sr_user == self.request.user) def started_at_text(dt: datetime.datetime): diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 995f5d43f..d557cd663 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -122,12 +122,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]: if not self.request.user.disable_safety_checker: safety_checker(text=state["text_prompt"]) - yield from TextToSpeechPage(request=self.request, run_user=self.run_user).run( - state - ) + yield from TextToSpeechPage(request=self.request).run(state) # IMP: Copy output of TextToSpeechPage "audio_url" to Lipsync as "input_audio" state["input_audio"] = state["audio_url"] - yield from LipsyncPage(request=self.request, run_user=self.run_user).run(state) + yield from LipsyncPage(request=self.request).run(state) def render_example(self, state: dict): output_video = state.get("output_video") diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 26b7e785e..77b266c26 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1015,9 +1015,7 @@ def run_v2( tts_state = TextToSpeechPage.RequestModel.parse_obj( {**gui.session_state, "text_prompt": text} ).dict() - yield from TextToSpeechPage( - request=self.request, run_user=self.run_user - ).run(tts_state) + yield from TextToSpeechPage(request=self.request).run(tts_state) response.output_audio.append(tts_state["audio_url"]) if not request.input_face: @@ -1031,9 +1029,7 @@ def run_v2( "selected_model": request.lipsync_model, } ).dict() - yield from LipsyncPage(request=self.request, run_user=self.run_user).run( - lip_state - ) + yield from LipsyncPage(request=self.request).run(lip_state) response.output_video.append(lip_state["output_video"]) def get_tabs(self): diff --git a/routers/api.py b/routers/api.py index 68ad4c2b6..73c5ec7b9 100644 --- a/routers/api.py +++ b/routers/api.py @@ -3,7 +3,6 @@ import os.path import os.path import typing -from types import SimpleNamespace import gooey_gui as gui from fastapi import Depends @@ -262,11 +261,7 @@ def get_run_status( user: AppUser = Depends(api_auth_header), ): # init a new page for every request - self = page_cls( - request=SimpleNamespace( - user=user, query_params=dict(run_id=run_id, uid=user.uid) - ) - ) + self = page_cls(user=user, query_params=dict(run_id=run_id, uid=user.uid)) sr = self.current_sr web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid))) ret = { @@ -344,9 +339,7 @@ def submit_api_call( ) -> tuple["celery.result.AsyncResult", "SavedRun"]: # init a new page for every request query_params.setdefault("uid", current_user.uid) - page = page_cls( - request=SimpleNamespace(user=current_user, query_params=query_params) - ) + page = page_cls(user=current_user, query_params=query_params) # get saved state from db state = page.current_sr_to_session_state() diff --git a/routers/root.py b/routers/root.py index 9f678416b..40884fd55 100644 --- a/routers/root.py +++ b/routers/root.py @@ -242,7 +242,7 @@ def explore_page(request: Request): @gui.route(app, "/api/") def api_docs_page(request: Request): with page_wrapper(request): - _api_docs_page(request) + _api_docs_page() return dict( meta=raw_build_meta_tags( url=get_og_url_path(request), @@ -255,7 +255,7 @@ def api_docs_page(request: Request): ) -def _api_docs_page(request): +def _api_docs_page(): from daras_ai_v2.all_pages import all_api_pages api_docs_url = str(furl(settings.API_BASE_URL) / "docs") @@ -312,7 +312,7 @@ def _api_docs_page(request): as_async = gui.checkbox("Run Async") as_form_data = gui.checkbox("Upload Files via Form Data") - page = workflow.page_cls(request=request) + page = workflow.page_cls() state = page.get_root_pr().saved_run.to_dict() api_url, request_body = page.get_example_request(state, include_all=include_all) response_body = page.get_example_response_body( @@ -667,11 +667,13 @@ def render_recipe_page( ) return RedirectResponse(str(new_url.set(origin=None)), status_code=301) - # this is because the code still expects example_id to be in the query params - request._query_params = dict(request.query_params) | dict(example_id=example_id) - - page = page_cls(tab=tab, request=request) - page.run_user = get_run_user(request, page.current_sr.uid) + page = page_cls( + tab=tab, + user=request.user, + request_session=request.session, + # this is because the code still expects example_id to be in the query params + query_params=dict(request.query_params) | dict(example_id=example_id), + ) if not gui.session_state: gui.session_state.update(page.current_sr_to_session_state()) @@ -692,17 +694,6 @@ def get_og_url_path(request) -> str: ) -def get_run_user(request: Request, uid: str) -> AppUser | None: - if not uid: - return - if request.user and request.user.uid == uid: - return request.user - try: - return AppUser.objects.get(uid=uid) - except AppUser.DoesNotExist: - pass - - @contextmanager def page_wrapper(request: Request, className=""): context = { diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 73e05e4a9..c78bf5376 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,6 +1,3 @@ -from types import SimpleNamespace - -import gooey_gui as gui import pytest from starlette.testclient import TestClient @@ -49,10 +46,8 @@ def test_copilot_get_raw_price_round_up(): dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) copilot_page = VideoBotsPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ), + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert ( copilot_page.get_price_roundoff(state=state) @@ -114,10 +109,8 @@ def test_multiple_llm_sums_usage_cost(): ) llm_page = CompareLLMPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ) + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -163,10 +156,8 @@ def test_workflowmetadata_2x_multiplier(): metadata.save() llm_page = CompareLLMPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ) + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2