diff --git a/bots/admin.py b/bots/admin.py index eaf200b04..b7e696dfd 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -439,7 +439,10 @@ def rerun_tasks(self, request, queryset): sr: SavedRun for sr in queryset.all(): page = Workflow(sr.workflow).page_cls( - request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid)) + request=SimpleNamespace( + user=AppUser.objects.get(uid=sr.uid), + query_params=dict(run_id=sr.run_id, uid=sr.uid), + ) ) page.call_runner_task(sr, deduct_credits=False) self.message_user( diff --git a/bots/models.py b/bots/models.py index e997e8f8a..3ace6bc33 100644 --- a/bots/models.py +++ b/bots/models.py @@ -127,16 +127,12 @@ def get_or_create_metadata(self) -> "WorkflowMetadata": workflow=self, create=lambda **kwargs: WorkflowMetadata.objects.create( **kwargs, - short_title=( - self.page_cls.get_root_published_run().title or self.page_cls.title - ), + short_title=(self.page_cls.get_root_pr().title or self.page_cls.title), default_image=self.page_cls.explore_image or "", - meta_title=( - self.page_cls.get_root_published_run().title or self.page_cls.title - ), + meta_title=(self.page_cls.get_root_pr().title or self.page_cls.title), meta_description=( self.page_cls().preview_description(state={}) - or self.page_cls.get_root_published_run().notes + or self.page_cls.get_root_pr().notes ), meta_image=self.page_cls.explore_image or "", ), @@ -389,7 +385,7 @@ def submit_api_call( ), ) - return result, page.run_doc_sr(run_id, uid) + return result, page.current_sr def get_creator(self) -> AppUser | None: if self.uid: diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 221754ddb..235b14138 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -1,5 +1,6 @@ import datetime import html +import threading import traceback import typing from time import time @@ -31,6 +32,15 @@ DEFAULT_RUN_STATUS = "Running..." +threadlocal = threading.local() + + +def get_running_saved_run() -> SavedRun | None: + try: + return threadlocal.saved_run + except AttributeError: + return None + @app.task def runner_task( @@ -81,12 +91,16 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save to db page.dump_state_to_sr(gui.session_state | output, sr) - user = AppUser.objects.get(id=user_id) - page = page_cls(request=SimpleNamespace(user=user)) + page = page_cls( + request=SimpleNamespace( + user=AppUser.objects.get(id=user_id), + query_params=dict(run_id=run_id, uid=uid), + ), + ) page.setup_sentry() - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr + threadlocal.saved_run = sr gui.set_session_state(sr.to_dict() | (unsaved_state or {})) - gui.set_query_params(dict(run_id=run_id, uid=uid)) try: save_on_step() @@ -114,6 +128,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save everything, mark run as completed finally: save_on_step(done=True) + threadlocal.saved_run = None post_runner_tasks.delay(sr.id) diff --git a/conftest.py b/conftest.py index 57d5d5cca..539ea848e 100644 --- a/conftest.py +++ b/conftest.py @@ -64,10 +64,10 @@ def mock_celery_tasks(): def _mock_runner_task( *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs ): - sr = page_cls.run_doc_sr(run_id, uid) + sr = page_cls.get_sr_from_ids(run_id, uid) sr.set(sr.parent.to_dict()) sr.save() - channel = page_cls().realtime_channel_name(run_id, uid) + channel = page_cls.realtime_channel_name(run_id, uid) _mock_realtime_push(channel, sr.to_dict()) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 2233a0803..e4d58e2a2 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -8,6 +8,7 @@ import uuid from copy import deepcopy, copy from enum import Enum +from functools import cached_property from itertools import pairwise from random import Random from time import sleep @@ -151,10 +152,13 @@ class RequestModel(BaseModel): def __init__( self, - tab: RecipeTabs = "", - request: Request | SimpleNamespace = None, - run_user: AppUser = None, + *, + tab: RecipeTabs = RecipeTabs.run, + request: Request | SimpleNamespace | None = None, + run_user: AppUser | None = None, ): + if request is None: + request = SimpleNamespace(user=None, query_params={}) self.tab = tab self.request = request self.run_user = run_user @@ -164,9 +168,8 @@ def __init__( def endpoint(cls) -> str: return f"/v2/{cls.slug_versions[0]}" - @classmethod def current_app_url( - cls, + self, tab: RecipeTabs = RecipeTabs.run, *, query_params: dict = None, @@ -174,8 +177,8 @@ def current_app_url( ) -> str: if query_params is None: query_params = {} - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.app_url( + example_id, run_id, uid = extract_query_params(self.request.query_params) + return self.app_url( tab=tab, example_id=example_id, run_id=run_id, @@ -209,7 +212,7 @@ def app_url( run_slug = None if example_id: try: - pr = cls.get_published_run(published_run_id=example_id) + pr = cls.get_pr_from_example_id(example_id=example_id) except PublishedRun.DoesNotExist: pr = None if pr and pr.title: @@ -225,11 +228,6 @@ def app_url( ) ) - @classmethod - def current_api_url(cls) -> furl | None: - pr = cls.get_current_published_run() - return cls.api_url(example_id=pr and pr.published_run_id) - @classmethod def api_url( cls, @@ -276,12 +274,12 @@ def sentry_event_set_request(self, event, hint): ) else: request["url"] = self.app_url( - tab=self.tab, query_params=gui.get_query_params() + tab=self.tab, query_params=dict(self.request.query_params) ) return event def sentry_event_set_user(self, event, hint): - if user := self.request and self.request.user: + if user := self.request.user: event["user"] = { "id": user.id, "name": user.display_name, @@ -305,7 +303,7 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - sr = self.get_current_sr() + sr = self.current_sr channel = self.realtime_channel_name(sr.run_id, sr.uid) output = gui.realtime_pull([channel])[0] if output: @@ -341,14 +339,11 @@ def render(self): self._render_header() def _render_header(self): - current_run = self.get_current_sr() - published_run = self.get_current_published_run() - is_example = published_run.saved_run == current_run - is_root_example = is_example and published_run.is_root() - tbreadcrumbs = get_title_breadcrumbs( - self, current_run, published_run, tab=self.tab - ) - can_save = self.can_user_save_run(current_run, published_run) + sr, pr = self.current_sr_pr + is_example = pr.saved_run == sr + is_root_example = is_example and pr.is_root() + tbreadcrumbs = get_title_breadcrumbs(self, sr, pr, tab=self.tab) + can_save = self.can_user_save_run(sr, pr) request_changed = self._has_request_changed() with gui.div(className="d-flex justify-content-between mt-4"): @@ -360,15 +355,13 @@ def _render_header(self): with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): render_breadcrumbs( tbreadcrumbs, - is_api_call=( - current_run.is_api_call and self.tab == RecipeTabs.run - ), + is_api_call=(sr.is_api_call and self.tab == RecipeTabs.run), ) if is_example: - author = published_run.created_by + author = pr.created_by else: - author = self.run_user or current_run.get_creator() + author = self.run_user or sr.get_creator() if not is_root_example: self.render_author(author) @@ -389,10 +382,7 @@ def _render_header(self): show_save_buttons = request_changed or can_save if show_save_buttons: - self._render_published_run_save_buttons( - current_run=current_run, - published_run=published_run, - ) + self._render_published_run_save_buttons(sr=sr, pr=pr) self._render_social_buttons(show_button_text=not show_save_buttons) if tbreadcrumbs.has_breadcrumbs() or self.run_user: @@ -401,15 +391,15 @@ def _render_header(self): if self.tab != RecipeTabs.run: return - if published_run and published_run.notes: - gui.write(published_run.notes, line_clamp=2) + if pr and pr.notes: + gui.write(pr.notes, line_clamp=2) elif is_root_example and self.tab != RecipeTabs.integrations: - gui.write(self.preview_description(current_run.to_dict()), line_clamp=2) + gui.write(self.preview_description(sr.to_dict()), line_clamp=2) def can_user_save_run( self, current_run: SavedRun, - published_run: PublishedRun | None, + published_run: PublishedRun, ) -> bool: return ( self.is_current_user_admin() @@ -425,13 +415,9 @@ def can_user_save_run( ) ) - def can_user_edit_published_run( - self, published_run: PublishedRun | None = None - ) -> bool: - published_run = published_run or self.get_current_published_run() + def can_user_edit_published_run(self, published_run: PublishedRun) -> bool: return self.is_current_user_admin() or bool( - published_run - and self.request + self.request and self.request.user and published_run.created_by_id and published_run.created_by_id == self.request.user.id @@ -460,13 +446,8 @@ def _render_social_buttons(self, show_button_text: bool = False): className="mb-0 ms-lg-2", ) - def _render_published_run_save_buttons( - self, - *, - current_run: SavedRun, - published_run: PublishedRun, - ): - can_edit = self.can_user_edit_published_run(published_run) + def _render_published_run_save_buttons(self, *, sr: SavedRun, pr: PublishedRun): + can_edit = self.can_user_edit_published_run(pr) with gui.div(className="d-flex justify-content-end"): gui.html( @@ -497,8 +478,8 @@ def _render_published_run_save_buttons( if options_modal.is_open(): with options_modal.container(style={"minWidth": "min(300px, 100vw)"}): self._render_options_modal( - current_run=current_run, - published_run=published_run, + current_run=sr, + published_run=pr, modal=options_modal, ) @@ -518,8 +499,8 @@ def _render_published_run_save_buttons( if publish_modal.is_open(): with publish_modal.container(style={"minWidth": "min(500px, 100vw)"}): self._render_publish_modal( - current_run=current_run, - published_run=published_run, + sr=sr, + pr=pr, modal=publish_modal, is_update_mode=can_edit, ) @@ -527,12 +508,12 @@ def _render_published_run_save_buttons( def _render_publish_modal( self, *, - current_run: SavedRun, - published_run: PublishedRun, + sr: SavedRun, + pr: PublishedRun, modal: gui.Modal, is_update_mode: bool = False, ): - if published_run.is_root() and self.is_current_user_admin(): + if pr.is_root() and self.is_current_user_admin(): with gui.div(className="text-danger"): gui.write( "###### You're about to update the root workflow as an admin. " @@ -564,7 +545,7 @@ def _render_publish_modal( "", options=options, format_func=options.__getitem__, - value=str(published_run.visibility), + value=str(pr.visibility), ) ) ) @@ -579,9 +560,9 @@ def _render_publish_modal( with gui.div(className="mt-4"): if is_update_mode: - title = published_run.title or self.title + title = pr.title or self.title else: - recipe_title = self.get_root_published_run().title or self.title + recipe_title = self.get_root_pr().title or self.title title = f"{self.request.user.first_name_possesive()} {recipe_title}" published_run_title = gui.text_input( "##### Title", @@ -591,11 +572,7 @@ def _render_publish_modal( published_run_notes = gui.text_area( "##### Notes", key="published_run_notes", - value=( - published_run.notes - or self.preview_description(gui.session_state) - or "" - ), + value=(pr.notes or self.preview_description(gui.session_state) or ""), ) with gui.div(className="mt-4 d-flex justify-content-center"): @@ -605,12 +582,12 @@ def _render_publish_modal( type="primary", ) - self._render_admin_options(current_run, published_run) + self._render_admin_options(sr, pr) if not pressed_save: return - is_root_published_run = is_update_mode and published_run.is_root() + is_root_published_run = is_update_mode and pr.is_root() if not is_root_published_run: try: self._validate_published_run_title(published_run_title) @@ -619,33 +596,31 @@ def _render_publish_modal( return if self._has_request_changed(): - current_run = self.on_submit() - if not current_run: + sr = self.on_submit() + if not sr: modal.close() if is_update_mode: updates = dict( - saved_run=current_run, + saved_run=sr, title=published_run_title.strip(), notes=published_run_notes.strip(), visibility=published_run_visibility, ) - if not self._has_published_run_changed( - published_run=published_run, **updates - ): + if not self._has_published_run_changed(published_run=pr, **updates): gui.error("No changes to publish", icon="⚠️") return - published_run.add_version(user=self.request.user, **updates) + pr.add_version(user=self.request.user, **updates) else: - published_run = self.create_published_run( + pr = self.create_published_run( published_run_id=get_random_doc_id(), - saved_run=current_run, + saved_run=sr, user=self.request.user, title=published_run_title.strip(), notes=published_run_notes.strip(), visibility=published_run_visibility, ) - raise gui.RedirectException(published_run.get_app_url()) + raise gui.RedirectException(pr.get_app_url()) def _validate_published_run_title(self, title: str): if slugify(title) in settings.DISALLOWED_TITLE_SLUGS: @@ -813,7 +788,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR className="text-danger", ) if gui.button("👌 Yes, Update the Root Workflow"): - root_run = self.get_root_published_run() + root_run = self.get_root_pr() root_run.add_version( user=self.request.user, title=published_run.title, @@ -825,7 +800,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR @classmethod def get_recipe_title(cls) -> str: - return cls.get_root_published_run().title or cls.title or cls.workflow.label + return cls.get_root_pr().title or cls.title or cls.workflow.label def get_explore_image(self) -> str: meta = self.workflow.get_or_create_metadata() @@ -853,7 +828,7 @@ def get_tabs(self): def render_selected_tab(self): match self.tab: case RecipeTabs.run: - if self.get_current_sr().retention_policy == RetentionPolicy.delete: + if self.current_sr.retention_policy == RetentionPolicy.delete: self.render_deleted_output() return @@ -884,15 +859,12 @@ def render_selected_tab(self): self._saved_tab() def _render_version_history(self): - published_run = self.get_current_published_run() - - if published_run: - versions = published_run.versions.all() - first_version = versions[0] - for version, older_version in pairwise(versions): - first_version = older_version - self._render_version_row(version, older_version) - self._render_version_row(first_version, None) + versions = self.current_pr.versions.all() + first_version = versions[0] + for version, older_version in pairwise(versions): + first_version = older_version + self._render_version_row(version, older_version) + self._render_version_row(first_version, None) def _render_version_row( self, @@ -957,7 +929,7 @@ def render_related_workflows(self): def _render(page_cls: typing.Type[BasePage]): page = page_cls() - root_run = page.get_root_published_run() + root_run = page.get_root_pr() state = root_run.saved_run.to_dict() preview_image = page.get_explore_image() @@ -1034,11 +1006,9 @@ def render_report_form(self): gui.error("Reason for report cannot be empty") return - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - send_reported_run_email( user=self.request.user, - run_uid=uid, + run_uid=str(self.run_user.uid), url=self.current_app_url(), recipe_name=self.title, report_type=report_type, @@ -1047,7 +1017,7 @@ def render_report_form(self): ) if report_type == inappropriate_radio_text: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=True) + self.update_flag_for_run(is_flagged=True) # gui.success("Reported.") gui.session_state["show_report_workflow"] = False @@ -1064,10 +1034,8 @@ def _check_if_flagged(self): if not unflag_pressed: return with gui.spinner("Removing flag..."): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - if run_id and uid: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=False) - gui.success("Removed flag.", icon="✅") + self.update_flag_for_run(is_flagged=False) + gui.success("Removed flag.") sleep(2) gui.rerun() else: @@ -1077,87 +1045,47 @@ def _check_if_flagged(self): # Return and Don't render the run any further gui.stop() - @classmethod - def get_runs_from_query_params( - cls, example_id: str, run_id: str, uid: str - ) -> tuple[SavedRun, PublishedRun | None]: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - pr = sr.parent_published_run() - else: - pr = cls.get_published_run(published_run_id=example_id or "") - sr = pr.saved_run - return sr, pr + def update_flag_for_run(self, is_flagged: bool): + sr = self.current_sr + sr.is_flagged = is_flagged + sr.save(update_fields=["is_flagged"]) + gui.session_state["is_flagged"] = is_flagged - @classmethod - def get_current_published_run(cls) -> PublishedRun | None: - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.get_pr_from_query_params(example_id, run_id, uid) + @property + def current_sr(self) -> SavedRun: + return self.current_sr_pr[0] - @classmethod - def get_pr_from_query_params( - cls, example_id: str, run_id: str, uid: str - ) -> PublishedRun | None: - if run_id and uid: - sr = cls.get_sr_from_query_params(example_id, run_id, uid) - return sr.parent_published_run() or cls.get_root_published_run() - elif example_id: - return cls.get_published_run(published_run_id=example_id) - else: - return cls.get_root_published_run() + @property + def current_pr(self) -> PublishedRun: + return self.current_sr_pr[1] - @classmethod - def get_published_run(cls, *, published_run_id: str): - return PublishedRun.objects.get( - workflow=cls.workflow, - published_run_id=published_run_id, + @cached_property + def current_sr_pr(self) -> tuple[SavedRun, PublishedRun]: + return self.get_sr_pr_from_query_params( + *extract_query_params(self.request.query_params) ) @classmethod - def get_current_sr(cls) -> SavedRun: - return cls.get_sr_from_query_params_dict(gui.get_query_params()) - - @classmethod - def get_sr_from_query_params_dict(cls, query_params) -> SavedRun: - example_id, run_id, uid = extract_query_params(query_params) - return cls.get_sr_from_query_params(example_id, run_id, uid) - - @classmethod - def get_sr_from_query_params( - cls, example_id: str | None, run_id: str | None, uid: str | None - ) -> SavedRun: - try: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - elif example_id: - pr = cls.get_published_run(published_run_id=example_id) - assert ( - pr.saved_run is not None - ), "invalid published run: without a saved run" - sr = pr.saved_run - else: - sr = cls.recipe_doc_sr() - return sr - except (SavedRun.DoesNotExist, PublishedRun.DoesNotExist): - raise HTTPException(status_code=404) - - @classmethod - def get_total_runs(cls) -> int: - # TODO: fix to also handle published run case - return SavedRun.objects.filter(workflow=cls.workflow).count() - - @classmethod - def recipe_doc_sr(cls, create: bool = True) -> SavedRun: - if create: - return cls.get_root_published_run().saved_run + def get_sr_pr_from_query_params( + cls, example_id: str, run_id: str, uid: str + ) -> tuple[SavedRun, PublishedRun]: + if run_id and uid: + sr = cls.get_sr_from_ids(run_id, uid) + pr = sr.parent_published_run() or cls.get_root_pr() else: - return cls.get_root_published_run().saved_run + if example_id: + pr = cls.get_pr_from_example_id(example_id=example_id) + else: + pr = cls.get_root_pr() + sr = pr.saved_run + return sr, pr @classmethod - def run_doc_sr( + def get_sr_from_ids( cls, run_id: str, uid: str, + *, create: bool = False, defaults: dict = None, ) -> SavedRun: @@ -1168,7 +1096,14 @@ def run_doc_sr( return SavedRun.objects.get(**config) @classmethod - def get_root_published_run(cls) -> PublishedRun: + def get_pr_from_example_id(cls, *, example_id: str): + return PublishedRun.objects.get( + workflow=cls.workflow, + published_run_id=example_id, + ) + + @classmethod + def get_root_pr(cls) -> PublishedRun: return PublishedRun.objects.get_or_create_with_version( workflow=cls.workflow, published_run_id="", @@ -1219,6 +1154,11 @@ def duplicate_published_run( visibility=visibility, ) + @classmethod + def get_total_runs(cls) -> int: + # TODO: fix to also handle published run case + return SavedRun.objects.filter(workflow=cls.workflow).count() + def render_description(self): pass @@ -1352,7 +1292,7 @@ def _render_step_row(self): with col2: placeholder = gui.div() render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre + saved_run=self.current_sr, trigger=FunctionTrigger.pre ) try: self.render_steps() @@ -1362,7 +1302,7 @@ def _render_step_row(self): with placeholder: gui.write("##### 👣 Steps") render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.post + saved_run=self.current_sr, trigger=FunctionTrigger.post ) def _render_help(self): @@ -1443,9 +1383,10 @@ def run_v2( raise NotImplementedError def _render_report_button(self): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - # only logged in users can report a run (but not examples/default runs) - if not (self.request.user and run_id and uid): + sr, pr = self.current_sr_pr + is_example = pr.saved_run_id == sr.id + # only logged in users can report a run (but not examples/root runs) + if not self.request.user or is_example: return reported = gui.button( @@ -1457,12 +1398,6 @@ def _render_report_button(self): gui.session_state["show_report_workflow"] = reported gui.rerun() - def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): - ref = self.run_doc_sr(uid=uid, run_id=run_id) - ref.is_flagged = is_flagged - ref.save(update_fields=["is_flagged"]) - gui.session_state["is_flagged"] = is_flagged - # Functions in every recipe feels like overkill for now, hide it in settings functions_in_settings = True show_settings = True @@ -1613,8 +1548,7 @@ def on_submit(self): def should_submit_after_login(self) -> bool: return ( - gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) - and self.request + self.request.query_params.get(SUBMIT_AFTER_LOGIN_Q) and self.request.user and not self.request.user.is_anonymous ) @@ -1643,19 +1577,13 @@ def create_new_run( run_id = get_random_doc_id() - parent_example_id, parent_run_id, parent_uid = extract_query_params( - gui.get_query_params() - ) - parent = self.get_sr_from_query_params( - parent_example_id, parent_run_id, parent_uid - ) - published_run = self.get_current_published_run() + parent, pr = self.current_sr_pr try: - parent_version = published_run and published_run.versions.latest() + parent_version = pr.versions.latest() except PublishedRunVersion.DoesNotExist: parent_version = None - sr = self.run_doc_sr( + sr = self.get_sr_from_ids( run_id, uid, create=True, @@ -1693,7 +1621,7 @@ def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True): ) @classmethod - def realtime_channel_name(cls, run_id, uid): + def realtime_channel_name(cls, run_id: str, uid: str) -> str: return f"gooey-outputs/{cls.slug_versions[0]}/{uid}/{run_id}" def generate_credit_error_message(self, run_id, uid) -> str: @@ -1845,7 +1773,7 @@ def _history_tab(self): if self.is_current_user_admin(): uid = self.request.query_params.get("uid", uid) - before = gui.get_query_params().get("updated_at__lt", None) + before = self.request.query_params.get("updated_at__lt", None) if before: before = datetime.datetime.fromisoformat(before) else: @@ -2047,11 +1975,10 @@ def run_as_api_tab(self): as_async = gui.checkbox("##### Run Async") as_form_data = gui.checkbox("##### Upload Files via Form Data") - pr = self.get_current_published_run() api_url, request_body = self.get_example_request( gui.session_state, include_all=include_all, - pr=pr, + pr=self.current_pr, ) response_body = self.get_example_response_body( gui.session_state, as_async=as_async, include_all=include_all @@ -2101,9 +2028,7 @@ def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict): raise InsufficientCredits(self.request.user, sr) def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: - assert ( - self.request and self.request.user - ), "request.user must be set to deduct credits" + assert self.request.user, "request.user must be set to deduct credits" amount = self.get_price_roundoff(state) txn = self.request.user.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") @@ -2120,7 +2045,7 @@ def get_raw_price(self, state: dict) -> float: def get_total_linked_usage_cost_in_credits(self, default=1): """Return the sum of the linked usage costs in gooey credits.""" - sr = self.get_current_sr() + sr = self.current_sr total = sr.usage_costs.aggregate(total=Sum("dollar_amount"))["total"] if not total: return default @@ -2128,10 +2053,8 @@ def get_total_linked_usage_cost_in_credits(self, default=1): def get_grouped_linked_usage_cost_in_credits(self): """Return the linked usage costs grouped by model name in gooey credits.""" - qs = ( - self.get_current_sr() - .usage_costs.values("pricing__model_name") - .annotate(total=Sum("dollar_amount") * settings.ADDON_CREDITS_PER_DOLLAR) + qs = self.current_sr.usage_costs.values("pricing__model_name").annotate( + total=Sum("dollar_amount") * settings.ADDON_CREDITS_PER_DOLLAR ) return {item["pricing__model_name"]: item["total"] for item in qs} @@ -2175,7 +2098,7 @@ def get_example_response_body( run_id=run_id, uid=self.request.user and self.request.user.uid, ) - sr = self.get_current_sr() + sr = self.current_sr output = sr.api_output(extract_model_fields(self.ResponseModel, state)) if as_async: return dict( @@ -2206,17 +2129,13 @@ def is_user_admin(cls, user: AppUser) -> bool: return email and email in settings.ADMIN_EMAILS def is_current_user_admin(self) -> bool: - if not self.request or not self.request.user: - return False - return self.is_user_admin(self.request.user) + return self.request.user and self.is_user_admin(self.request.user) def is_current_user_paying(self) -> bool: - return bool(self.request and self.request.user and self.request.user.is_paying) + return bool(self.request.user and self.request.user.is_paying) def is_current_user_owner(self) -> bool: - return bool( - self.request and self.request.user and self.run_user == self.request.user - ) + return bool(self.request.user and self.run_user == self.request.user) def started_at_text(dt: datetime.datetime): diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 15d89b4d6..ccf864121 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -7,6 +7,7 @@ from django.db import transaction from django.utils.text import slugify from furl import furl +from starlette.requests import Request from app_users.models import AppUser from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform @@ -54,7 +55,7 @@ def integrations_welcome_screen(title: str): gui.caption("Analyze your usage. Update your Saved Run to test changes.") -def general_integration_settings(bi: BotIntegration, current_user: AppUser): +def general_integration_settings(bi: BotIntegration, request: Request): if gui.session_state.get(f"_bi_reset_{bi.id}"): gui.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( BotIntegration._meta.get_field("streaming_enabled").default @@ -101,9 +102,10 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): "📊 View Results", str( furl( - VideoBotsPage.current_app_url( - RecipeTabs.integrations, + VideoBotsPage.app_url( + tab=RecipeTabs.integrations, path_params=dict(integration_id=bi.api_integration_id()), + query_params=dict(request.query_params), ) ) / "analysis/" @@ -119,7 +121,7 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): key=key, internal_state=d, del_key=del_key, - current_user=current_user, + current_user=request.user, ) if not ret: return diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 90fef04ee..eea020936 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -440,7 +440,7 @@ def _process_and_send_msg( # wait for the celery task to finish get_celery_result_db_safe(result) # get the final state from db - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr state = sr.to_dict() bot.recipe_run_state = page.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 9b46c9cdf..186075646 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -129,7 +129,7 @@ def doc_extract_selector(current_user: AppUser | None): gui.write("###### Create Synthetic Data") gui.caption( f""" - To improve answer quality, pick a [synthetic data maker workflow]({DocExtractPage.get_root_published_run().get_app_url()}) to scan & OCR any images in your documents or transcribe & translate any videos. It also can synthesize a helpful FAQ. Adds ~2 minutes of one-time processing per file. + To improve answer quality, pick a [synthetic data maker workflow]({DocExtractPage.get_root_pr().get_app_url()}) to scan & OCR any images in your documents or transcribe & translate any videos. It also can synthesize a helpful FAQ. Adds ~2 minutes of one-time processing per file. """ ) workflow_url_input( diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index c3335434b..d2c255a62 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -17,11 +17,8 @@ def build_meta_tags( url: str, page: "BasePage", state: dict, - run_id: str, - uid: str, - example_id: str, ) -> list[dict]: - sr, pr = page.get_runs_from_query_params(example_id, run_id, uid) + sr, pr = page.current_sr_pr metadata = page.workflow.get_or_create_metadata() title = meta_title_for_page( diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 84c14baf5..8a8d67336 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -29,7 +29,7 @@ def safety_checker_text(text_input: str): # run in a thread to avoid messing up threadlocals result, sr = ( CompareLLMPage() - .get_published_run(published_run_id=settings.SAFETY_CHECKER_EXAMPLE_ID) + .get_pr_from_example_id(example_id=settings.SAFETY_CHECKER_EXAMPLE_ID) .submit_api_call( current_user=billing_account, request_body=dict(variables=dict(input=text_input)), diff --git a/daras_ai_v2/workflow_url_input.py b/daras_ai_v2/workflow_url_input.py index 0b82cd2a8..b9f23cc56 100644 --- a/daras_ai_v2/workflow_url_input.py +++ b/daras_ai_v2/workflow_url_input.py @@ -136,7 +136,7 @@ def url_to_runs( assert match, "Not a valid Gooey.AI URL" page_cls = page_slug_map[normalize_slug(match.matched_params["page_slug"])] example_id, run_id, uid = extract_query_params(furl(url).query.params) - sr, pr = page_cls.get_runs_from_query_params( + sr, pr = page_cls.get_sr_pr_from_query_params( example_id or match.matched_params.get("example_id"), run_id, uid ) return page_cls, sr, pr @@ -177,7 +177,7 @@ def get_published_run_options( if include_root: # include root recipe if requested options_dict = { - page_cls.get_root_published_run().get_app_url(): "Default", + page_cls.get_root_pr().get_app_url(): "Default", } | options_dict return options_dict diff --git a/explore.py b/explore.py index cd4ec3d00..5bbe380fc 100644 --- a/explore.py +++ b/explore.py @@ -85,7 +85,7 @@ def render_description(page: BasePage): with gui.link(to=page.app_url()): gui.markdown(f"#### {page.get_recipe_title()}") - root_pr = page.get_root_published_run() + root_pr = page.get_root_pr() notes = root_pr.notes or page.preview_description(state=page.sane_defaults) with gui.tag("p", style={"marginBottom": "25px"}): gui.write(notes, line_clamp=4) diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 3bbccef25..08275f6f8 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -136,7 +136,7 @@ def render_settings(self): gui.write("---") gui.write("##### 🔎 Document Search Settings") citation_style_selector() - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) query_instructions_widget() gui.write("---") doc_search_advanced_settings() @@ -175,7 +175,7 @@ def run_v2( "search_query": response.final_search_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) # empty search result, abort! diff --git a/recipes/Functions.py b/recipes/Functions.py index 356381343..81b99c946 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -59,10 +59,8 @@ def run_v2( request: "FunctionsPage.RequestModel", response: "FunctionsPage.ResponseModel", ) -> typing.Iterator[str | None]: - query_params = gui.get_query_params() - run_id = query_params.get("run_id") - uid = query_params.get("uid") - tag = f"run_id={run_id}&uid={uid}" + sr = self.current_sr + tag = f"run_id={sr.run_id}&uid={sr.uid}" yield "Running your code..." # this will run functions/executor.js in deno deploy @@ -86,7 +84,7 @@ def render_form_v2(self): ) def get_price_roundoff(self, state: dict) -> float: - if CalledFunction.objects.filter(function_run=self.get_current_sr()).exists(): + if CalledFunction.objects.filter(function_run=self.current_sr).exists(): return 0 return super().get_price_roundoff(state) diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 6a9eaf999..611483287 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -254,7 +254,7 @@ def run_v2( }, ), is_user_url=False, - current_user=self.request and self.request.user, + current_user=self.request.user, ) # add pretty titles to references for ref in response.references: diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 41e7a33a8..26b7e785e 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -85,7 +85,6 @@ from daras_ai_v2.prompt_vars import render_prompt_vars from daras_ai_v2.pydantic_validation import FieldHttpUrl from daras_ai_v2.query_generator import generate_final_search_query -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.search_ref import ( parse_refs, CitationStyles, @@ -521,7 +520,7 @@ def render_settings(self): citation_style_selector() gui.checkbox("🔗 Shorten Citation URLs", key="use_url_shortener") - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) gui.write("---") @@ -886,7 +885,7 @@ def run_v2( "keyword_query": response.final_keyword_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) if request.use_url_shortener: for reference in response.references: @@ -1061,9 +1060,7 @@ def render_integrations_tab(self): gui.anchor("Get Started", href=self.get_auth_url(), type="primary") return - sr, pr = self.get_runs_from_query_params( - *extract_query_params(gui.get_query_params()) - ) + sr, pr = self.current_sr_pr # make user the user knows that they are on a saved run not the published run if pr and pr.saved_run_id != sr.id: @@ -1377,7 +1374,7 @@ def render_integrations_settings( slack_specific_settings(bi, run_title) if bi.platform == Platform.TWILIO: twilio_specific_settings(bi) - general_integration_settings(bi, self.request.user) + general_integration_settings(bi, self.request) if bi.platform in [Platform.SLACK, Platform.WHATSAPP, Platform.TWILIO]: gui.newline() diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 648a2894f..3f3bbacb6 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -84,8 +84,9 @@ def show_title_breadcrumb_share( ) gui.breadcrumb_item( "Integrations", - link_to=VideoBotsPage.current_app_url( - RecipeTabs.integrations, + link_to=VideoBotsPage.app_url( + tab=RecipeTabs.integrations, + query_params=dict(self.request.query_params), path_params=dict( integration_id=bi.api_integration_id() ), @@ -152,7 +153,7 @@ def render(self): ) ) - run_url = VideoBotsPage.current_app_url() + run_url = VideoBotsPage.app_url(query_params=dict(self.request.query_params)) if bi.published_run_id: run_title = bi.published_run.title else: diff --git a/routers/api.py b/routers/api.py index 9b795d426..dd74b5a00 100644 --- a/routers/api.py +++ b/routers/api.py @@ -258,8 +258,13 @@ def get_run_status( run_id: str, user: AppUser = Depends(api_auth_header), ): - self = page_cls() - sr = self.get_sr_from_query_params(example_id=None, run_id=run_id, uid=user.uid) + # init a new page for every request + self = page_cls( + request=SimpleNamespace( + user=user, query_params=dict(run_id=run_id, uid=user.uid) + ) + ) + sr = self.current_sr web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid))) ret = { "run_id": run_id, @@ -335,18 +340,17 @@ def submit_api_call( deduct_credits: bool = True, ) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: # init a new page for every request - self = page_cls(request=SimpleNamespace(user=user)) + query_params.setdefault("uid", user.uid) + self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) # get saved state from db - query_params.setdefault("uid", user.uid) - sr = self.get_sr_from_query_params_dict(query_params) + sr = self.current_sr state = self.load_state_from_sr(sr) # load request data state.update(request_body) # set streamlit session state gui.set_session_state(state) - gui.set_query_params(query_params) # create a new run try: @@ -369,7 +373,7 @@ def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: / page.endpoint.replace("v2", "v3") / "status/" ) - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr return dict( run_id=run_id, web_url=web_url, @@ -388,7 +392,8 @@ def build_sync_api_response( web_url = page.app_url(run_id=run_id, uid=uid) # wait for the result get_celery_result_db_safe(result) - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr + sr.refresh_from_db() if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) diff --git a/routers/bots_api.py b/routers/bots_api.py index 780a7b918..64285cbd2 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -302,7 +302,7 @@ def runner(self): msg_handler(self) # raise ValueError("Stream ended") if self.run_id and self.uid: - sr = self.page_cls.run_doc_sr(run_id=self.run_id, uid=self.uid) + sr = self.page_cls.get_sr_from_ids(run_id=self.run_id, uid=self.uid) state = sr.to_dict() self.queue.put( FinalResponse( diff --git a/routers/root.py b/routers/root.py index b3a636207..5d31b1ceb 100644 --- a/routers/root.py +++ b/routers/root.py @@ -39,7 +39,6 @@ from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.profiles import user_profile_page, get_meta_tags_for_profile -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.settings import templates from handles.models import Handle from routers.custom_api_router import CustomAPIRouter @@ -314,7 +313,7 @@ def _api_docs_page(request): as_form_data = gui.checkbox("Upload Files via Form Data") page = workflow.page_cls(request=request) - state = page.get_root_published_run().saved_run.to_dict() + state = page.get_root_pr().saved_run.to_dict() api_url, request_body = page.get_example_request(state, include_all=include_all) response_body = page.get_example_response_body( state, as_async=as_async, include_all=include_all @@ -669,12 +668,13 @@ def render_recipe_page( return RedirectResponse(str(new_url.set(origin=None)), status_code=301) # this is because the code still expects example_id to be in the query params - gui.set_query_params(dict(request.query_params) | dict(example_id=example_id)) - _, run_id, uid = extract_query_params(request.query_params) + request._query_params = dict(request.query_params) | dict(example_id=example_id) + + page = page_cls(tab=tab, request=request) + sr = page.current_sr + page.run_user = get_run_user(request, sr.uid) - page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid)) if not gui.session_state: - sr = page.get_sr_from_query_params(example_id, run_id, uid) gui.session_state.update(page.load_state_from_sr(sr)) with page_wrapper(request): @@ -682,12 +682,7 @@ def render_recipe_page( return dict( meta=build_meta_tags( - url=get_og_url_path(request), - page=page, - state=gui.session_state, - run_id=run_id, - uid=uid, - example_id=example_id, + url=get_og_url_path(request), page=page, state=gui.session_state ), ) @@ -698,7 +693,7 @@ def get_og_url_path(request) -> str: ) -def get_run_user(request, uid) -> AppUser | None: +def get_run_user(request: Request, uid: str) -> AppUser | None: if not uid: return if request.user and request.user.uid == uid: diff --git a/routers/twilio_api.py b/routers/twilio_api.py index 7893c15d1..0223636a9 100644 --- a/routers/twilio_api.py +++ b/routers/twilio_api.py @@ -262,7 +262,7 @@ def resp_say_or_tts_play( tts_state = TextToSpeechPage.RequestModel.parse_obj( {**bot.saved_run.state, "text_prompt": text} ).dict() - result, sr = TextToSpeechPage.get_root_published_run().submit_api_call( + result, sr = TextToSpeechPage.get_root_pr().submit_api_call( current_user=AppUser.objects.get(uid=bot.billing_account_uid), request_body=tts_state, ) diff --git a/tests/test_apis.py b/tests/test_apis.py index fa897eb83..7220798e3 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -21,7 +21,7 @@ def test_apis_sync(mock_celery_tasks, force_authentication, threadpool_subtest): def _test_api_sync(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_pr().saved_run.state r = client.post( f"/v2/{page_cls.slug_versions[0]}/", json=page_cls.get_example_request(state)[1], @@ -38,7 +38,7 @@ def test_apis_async(mock_celery_tasks, force_authentication, threadpool_subtest) def _test_api_async(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_pr().saved_run.state r = client.post( f"/v3/{page_cls.slug_versions[0]}/async/", diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 6ac6e0591..73e05e4a9 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import gooey_gui as gui import pytest from starlette.testclient import TestClient @@ -46,8 +48,12 @@ def test_copilot_get_raw_price_round_up(): unit_quantity=model_pricing.unit_quantity, dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) - copilot_page = VideoBotsPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + copilot_page = VideoBotsPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ), + ) assert ( copilot_page.get_price_roundoff(state=state) == 210 + copilot_page.PROFIT_CREDITS @@ -107,8 +113,12 @@ def test_multiple_llm_sums_usage_cost(): dollar_amount=model_pricing2.unit_cost * 1 / model_pricing2.unit_quantity, ) - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -152,8 +162,12 @@ def test_workflowmetadata_2x_multiplier(): metadata.price_multiplier = 2 metadata.save() - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2 ) diff --git a/url_shortener/models.py b/url_shortener/models.py index 21b0864d7..515b293a2 100644 --- a/url_shortener/models.py +++ b/url_shortener/models.py @@ -6,10 +6,9 @@ from app_users.models import AppUser from bots.custom_fields import CustomURLField from bots.models import Workflow, SavedRun +from celeryapp.tasks import get_running_saved_run from daras_ai.image_input import truncate_filename from daras_ai_v2 import settings -from daras_ai_v2.query_params_util import extract_query_params -import gooey_gui as gui class ShortenedURLQuerySet(models.QuerySet): @@ -17,14 +16,8 @@ def get_or_create_for_workflow( self, *, user: AppUser, workflow: Workflow, **kwargs ) -> tuple["ShortenedURL", bool]: surl, created = self.filter_first_or_create(user=user, **kwargs) - _, run_id, uid = extract_query_params(gui.get_query_params()) - surl.saved_runs.add( - SavedRun.objects.get_or_create( - workflow=workflow, - run_id=run_id, - uid=uid, - )[0], - ) + sr = get_running_saved_run() + surl.saved_runs.add(sr) return surl, created def filter_first_or_create(self, defaults=None, **kwargs): diff --git a/usage_costs/cost_utils.py b/usage_costs/cost_utils.py index 6596c0d22..b380ee63f 100644 --- a/usage_costs/cost_utils.py +++ b/usage_costs/cost_utils.py @@ -1,19 +1,16 @@ from loguru import logger -from daras_ai_v2.query_params_util import extract_query_params +from celeryapp.tasks import get_running_saved_run from usage_costs.models import ( UsageCost, ModelSku, ModelPricing, ) -import gooey_gui as gui def record_cost_auto(model: str, sku: ModelSku, quantity: int): - from bots.models import SavedRun - - _, run_id, uid = extract_query_params(gui.get_query_params()) - if not run_id or not uid: + sr = get_running_saved_run() + if not sr: return try: @@ -22,10 +19,8 @@ def record_cost_auto(model: str, sku: ModelSku, quantity: int): logger.warning(f"Cant find pricing for {model=} {sku=}: {e=}") return - saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) - UsageCost.objects.create( - saved_run=saved_run, + saved_run=sr, pricing=pricing, quantity=quantity, unit_cost=pricing.unit_cost,