diff --git a/bots/admin.py b/bots/admin.py index 3a4ae8e20..cfa06bc7c 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -437,7 +437,10 @@ def rerun_tasks(self, request, queryset): sr: SavedRun for sr in queryset.all(): page = Workflow(sr.workflow).page_cls( - request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid)) + request=SimpleNamespace( + user=AppUser.objects.get(uid=sr.uid), + query_params=dict(run_id=sr.run_id, uid=sr.uid), + ) ) page.call_runner_task(sr, deduct_credits=False) self.message_user( diff --git a/bots/models.py b/bots/models.py index 2566e976f..35867532f 100644 --- a/bots/models.py +++ b/bots/models.py @@ -389,7 +389,7 @@ def submit_api_call( ), ) - return result, page.run_doc_sr(run_id, uid) + return result, page.get_sr_pr()[0] def get_creator(self) -> AppUser | None: if self.uid: diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 221754ddb..c6d21f87d 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -1,5 +1,6 @@ import datetime import html +import threading import traceback import typing from time import time @@ -31,6 +32,15 @@ DEFAULT_RUN_STATUS = "Running..." +threadlocal = threading.local() + + +def get_current_saved_run() -> SavedRun | None: + try: + return threadlocal.saved_run + except AttributeError: + return None + @app.task def runner_task( @@ -81,12 +91,16 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save to db page.dump_state_to_sr(gui.session_state | output, sr) - user = AppUser.objects.get(id=user_id) - page = page_cls(request=SimpleNamespace(user=user)) + page = page_cls( + request=SimpleNamespace( + user=AppUser.objects.get(id=user_id), + query_params=dict(run_id=run_id, uid=uid), + ), + ) page.setup_sentry() - sr = page.run_doc_sr(run_id, uid) + sr = page.get_sr_pr()[0] + threadlocal.saved_run = sr gui.set_session_state(sr.to_dict() | (unsaved_state or {})) - gui.set_query_params(dict(run_id=run_id, uid=uid)) try: save_on_step() @@ -114,6 +128,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save everything, mark run as completed finally: save_on_step(done=True) + threadlocal.saved_run = None post_runner_tasks.delay(sr.id) diff --git a/conftest.py b/conftest.py index 57d5d5cca..e4ce8edfd 100644 --- a/conftest.py +++ b/conftest.py @@ -64,10 +64,10 @@ def mock_celery_tasks(): def _mock_runner_task( *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs ): - sr = page_cls.run_doc_sr(run_id, uid) + sr = page_cls.get_saved_run(run_id, uid) sr.set(sr.parent.to_dict()) sr.save() - channel = page_cls().realtime_channel_name(run_id, uid) + channel = page_cls.realtime_channel_name(run_id, uid) _mock_realtime_push(channel, sr.to_dict()) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index d3a52175d..033df43d4 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -151,10 +151,13 @@ class RequestModel(BaseModel): def __init__( self, - tab: RecipeTabs = "", - request: Request | SimpleNamespace = None, - run_user: AppUser = None, + *, + tab: RecipeTabs = RecipeTabs.run, + request: Request | SimpleNamespace | None = None, + run_user: AppUser | None = None, ): + if request is None: + request = SimpleNamespace(user=None, query_params={}) self.tab = tab self.request = request self.run_user = run_user @@ -164,9 +167,8 @@ def __init__( def endpoint(cls) -> str: return f"/v2/{cls.slug_versions[0]}" - @classmethod def current_app_url( - cls, + self, tab: RecipeTabs = RecipeTabs.run, *, query_params: dict = None, @@ -174,8 +176,8 @@ def current_app_url( ) -> str: if query_params is None: query_params = {} - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.app_url( + example_id, run_id, uid = extract_query_params(self.request.query_params) + return self.app_url( tab=tab, example_id=example_id, run_id=run_id, @@ -225,11 +227,6 @@ def app_url( ) ) - @classmethod - def current_api_url(cls) -> furl | None: - pr = cls.get_current_published_run() - return cls.api_url(example_id=pr and pr.published_run_id) - @classmethod def api_url( cls, @@ -276,12 +273,12 @@ def sentry_event_set_request(self, event, hint): ) else: request["url"] = self.app_url( - tab=self.tab, query_params=gui.get_query_params() + tab=self.tab, query_params=dict(self.request.query_params) ) return event def sentry_event_set_user(self, event, hint): - if user := self.request and self.request.user: + if user := self.request.user: event["user"] = { "id": user.id, "name": user.display_name, @@ -305,7 +302,7 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - sr = self.get_current_sr() + sr = self.get_sr_pr()[0] channel = self.realtime_channel_name(sr.run_id, sr.uid) output = gui.realtime_pull([channel])[0] if output: @@ -341,14 +338,11 @@ def render(self): self._render_header() def _render_header(self): - current_run = self.get_current_sr() - published_run = self.get_current_published_run() - is_example = published_run.saved_run == current_run - is_root_example = is_example and published_run.is_root() - tbreadcrumbs = get_title_breadcrumbs( - self, current_run, published_run, tab=self.tab - ) - can_edit = self.can_user_edit_run(current_run, published_run) + sr, pr = self.get_sr_pr() + is_example = pr.saved_run == sr + is_root_example = is_example and pr.is_root() + tbreadcrumbs = get_title_breadcrumbs(self, sr, pr, tab=self.tab) + can_edit = self.can_user_edit_run(sr, pr) request_changed = self._has_request_changed() with gui.div(className="d-flex justify-content-between mt-4"): @@ -360,15 +354,13 @@ def _render_header(self): with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): render_breadcrumbs( tbreadcrumbs, - is_api_call=( - current_run.is_api_call and self.tab == RecipeTabs.run - ), + is_api_call=(sr.is_api_call and self.tab == RecipeTabs.run), ) if is_example: - author = published_run.created_by + author = pr.created_by else: - author = self.run_user or current_run.get_creator() + author = self.run_user or sr.get_creator() if not is_root_example: self.render_author(author) @@ -390,8 +382,8 @@ def _render_header(self): show_save_buttons = request_changed or can_edit if show_save_buttons: self._render_published_run_save_buttons( - current_run=current_run, - published_run=published_run, + current_run=sr, + published_run=pr, ) self._render_social_buttons(show_button_text=not show_save_buttons) @@ -401,15 +393,15 @@ def _render_header(self): if self.tab != RecipeTabs.run: return - if published_run and published_run.notes: - gui.write(published_run.notes, line_clamp=2) + if pr and pr.notes: + gui.write(pr.notes, line_clamp=2) elif is_root_example and self.tab != RecipeTabs.integrations: - gui.write(self.preview_description(current_run.to_dict()), line_clamp=2) + gui.write(self.preview_description(sr.to_dict()), line_clamp=2) def can_user_edit_run( self, current_run: SavedRun, - published_run: PublishedRun | None, + published_run: PublishedRun, ) -> bool: return ( self.is_current_user_admin() @@ -425,13 +417,9 @@ def can_user_edit_run( ) ) - def can_user_edit_published_run( - self, published_run: PublishedRun | None = None - ) -> bool: - published_run = published_run or self.get_current_published_run() + def can_user_edit_published_run(self, published_run: PublishedRun) -> bool: return self.is_current_user_admin() or bool( - published_run - and self.request + self.request and self.request.user and published_run.created_by == self.request.user ) @@ -852,7 +840,7 @@ def get_tabs(self): def render_selected_tab(self): match self.tab: case RecipeTabs.run: - if self.get_current_sr().retention_policy == RetentionPolicy.delete: + if self.get_sr_pr()[0].retention_policy == RetentionPolicy.delete: self.render_deleted_output() return @@ -883,15 +871,13 @@ def render_selected_tab(self): self._saved_tab() def _render_version_history(self): - published_run = self.get_current_published_run() - - if published_run: - versions = published_run.versions.all() - first_version = versions[0] - for version, older_version in pairwise(versions): - first_version = older_version - self._render_version_row(version, older_version) - self._render_version_row(first_version, None) + published_run = self.get_sr_pr()[1] + versions = published_run.versions.all() + first_version = versions[0] + for version, older_version in pairwise(versions): + first_version = older_version + self._render_version_row(version, older_version) + self._render_version_row(first_version, None) def _render_version_row( self, @@ -1033,11 +1019,9 @@ def render_report_form(self): gui.error("Reason for report cannot be empty") return - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - send_reported_run_email( user=self.request.user, - run_uid=uid, + run_uid=str(self.run_user.uid), url=self.current_app_url(), recipe_name=self.title, report_type=report_type, @@ -1046,7 +1030,7 @@ def render_report_form(self): ) if report_type == inappropriate_radio_text: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=True) + self.update_flag_for_run(is_flagged=True) # gui.success("Reported.") gui.session_state["show_report_workflow"] = False @@ -1063,10 +1047,8 @@ def _check_if_flagged(self): if not unflag_pressed: return with gui.spinner("Removing flag..."): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - if run_id and uid: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=False) - gui.success("Removed flag.", icon="✅") + self.update_flag_for_run(is_flagged=False) + gui.success("Removed flag.") sleep(2) gui.rerun() else: @@ -1076,84 +1058,38 @@ def _check_if_flagged(self): # Return and Don't render the run any further gui.stop() - @classmethod - def get_runs_from_query_params( - cls, example_id: str, run_id: str, uid: str - ) -> tuple[SavedRun, PublishedRun | None]: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - pr = sr.parent_published_run() - else: - pr = cls.get_published_run(published_run_id=example_id or "") - sr = pr.saved_run - return sr, pr + def update_flag_for_run(self, is_flagged: bool): + sr = self.get_sr_pr()[0] + sr.is_flagged = is_flagged + sr.save(update_fields=["is_flagged"]) + gui.session_state["is_flagged"] = is_flagged - @classmethod - def get_current_published_run(cls) -> PublishedRun | None: - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.get_pr_from_query_params(example_id, run_id, uid) + _current_runs: tuple[SavedRun, PublishedRun] | None = None + + def get_sr_pr(self) -> tuple[SavedRun, PublishedRun]: + if not self._current_runs: + self._current_runs = self.get_sr_pr_from_query_params( + *extract_query_params(self.request.query_params) + ) + return self._current_runs @classmethod - def get_pr_from_query_params( + def get_sr_pr_from_query_params( cls, example_id: str, run_id: str, uid: str - ) -> PublishedRun | None: + ) -> tuple[SavedRun, PublishedRun]: if run_id and uid: - sr = cls.get_sr_from_query_params(example_id, run_id, uid) - return sr.parent_published_run() - elif example_id: - return cls.get_published_run(published_run_id=example_id) + sr = cls.get_saved_run(run_id, uid) + pr = sr.parent_published_run() else: - return cls.get_root_published_run() - - @classmethod - def get_published_run(cls, *, published_run_id: str): - return PublishedRun.objects.get( - workflow=cls.workflow, - published_run_id=published_run_id, - ) - - @classmethod - def get_current_sr(cls) -> SavedRun: - return cls.get_sr_from_query_params_dict(gui.get_query_params()) - - @classmethod - def get_sr_from_query_params_dict(cls, query_params) -> SavedRun: - example_id, run_id, uid = extract_query_params(query_params) - return cls.get_sr_from_query_params(example_id, run_id, uid) - - @classmethod - def get_sr_from_query_params( - cls, example_id: str | None, run_id: str | None, uid: str | None - ) -> SavedRun: - try: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - elif example_id: + if example_id: pr = cls.get_published_run(published_run_id=example_id) - assert ( - pr.saved_run is not None - ), "invalid published run: without a saved run" - sr = pr.saved_run else: - sr = cls.recipe_doc_sr() - return sr - except (SavedRun.DoesNotExist, PublishedRun.DoesNotExist): - raise HTTPException(status_code=404) - - @classmethod - def get_total_runs(cls) -> int: - # TODO: fix to also handle published run case - return SavedRun.objects.filter(workflow=cls.workflow).count() - - @classmethod - def recipe_doc_sr(cls, create: bool = True) -> SavedRun: - if create: - return cls.get_root_published_run().saved_run - else: - return cls.get_root_published_run().saved_run + pr = cls.get_root_published_run() + sr = pr.saved_run + return sr, pr @classmethod - def run_doc_sr( + def get_saved_run( cls, run_id: str, uid: str, @@ -1166,6 +1102,13 @@ def run_doc_sr( else: return SavedRun.objects.get(**config) + @classmethod + def get_published_run(cls, *, published_run_id: str): + return PublishedRun.objects.get( + workflow=cls.workflow, + published_run_id=published_run_id, + ) + @classmethod def get_root_published_run(cls) -> PublishedRun: return PublishedRun.objects.get_or_create_with_version( @@ -1218,6 +1161,11 @@ def duplicate_published_run( visibility=visibility, ) + @classmethod + def get_total_runs(cls) -> int: + # TODO: fix to also handle published run case + return SavedRun.objects.filter(workflow=cls.workflow).count() + def render_description(self): pass @@ -1351,7 +1299,7 @@ def _render_step_row(self): with col2: placeholder = gui.div() render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre + saved_run=self.get_sr_pr()[0], trigger=FunctionTrigger.pre ) try: self.render_steps() @@ -1361,7 +1309,7 @@ def _render_step_row(self): with placeholder: gui.write("##### 👣 Steps") render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.post + saved_run=self.get_sr_pr()[0], trigger=FunctionTrigger.post ) def _render_help(self): @@ -1442,9 +1390,10 @@ def run_v2( raise NotImplementedError def _render_report_button(self): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - # only logged in users can report a run (but not examples/default runs) - if not (self.request.user and run_id and uid): + sr, pr = self.get_sr_pr() + is_example = pr.saved_run_id == sr.id + # only logged in users can report a run (but not examples/root runs) + if not self.request.user or is_example: return reported = gui.button( @@ -1456,12 +1405,6 @@ def _render_report_button(self): gui.session_state["show_report_workflow"] = reported gui.rerun() - def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): - ref = self.run_doc_sr(uid=uid, run_id=run_id) - ref.is_flagged = is_flagged - ref.save(update_fields=["is_flagged"]) - gui.session_state["is_flagged"] = is_flagged - # Functions in every recipe feels like overkill for now, hide it in settings functions_in_settings = True show_settings = True @@ -1612,8 +1555,7 @@ def on_submit(self): def should_submit_after_login(self) -> bool: return ( - gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) - and self.request + self.request.query_params.get(SUBMIT_AFTER_LOGIN_Q) and self.request.user and not self.request.user.is_anonymous ) @@ -1642,19 +1584,13 @@ def create_new_run( run_id = get_random_doc_id() - parent_example_id, parent_run_id, parent_uid = extract_query_params( - gui.get_query_params() - ) - parent = self.get_sr_from_query_params( - parent_example_id, parent_run_id, parent_uid - ) - published_run = self.get_current_published_run() + parent, pr = self.get_sr_pr() try: - parent_version = published_run and published_run.versions.latest() + parent_version = pr.versions.latest() except PublishedRunVersion.DoesNotExist: parent_version = None - sr = self.run_doc_sr( + sr = self.get_saved_run( run_id, uid, create=True, @@ -1692,7 +1628,7 @@ def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True): ) @classmethod - def realtime_channel_name(cls, run_id, uid): + def realtime_channel_name(cls, run_id: str, uid: str) -> str: return f"gooey-outputs/{cls.slug_versions[0]}/{uid}/{run_id}" def generate_credit_error_message(self, run_id, uid) -> str: @@ -1844,7 +1780,7 @@ def _history_tab(self): if self.is_current_user_admin(): uid = self.request.query_params.get("uid", uid) - before = gui.get_query_params().get("updated_at__lt", None) + before = self.request.query_params.get("updated_at__lt", None) if before: before = datetime.datetime.fromisoformat(before) else: @@ -2046,7 +1982,7 @@ def run_as_api_tab(self): as_async = gui.checkbox("##### Run Async") as_form_data = gui.checkbox("##### Upload Files via Form Data") - pr = self.get_current_published_run() + pr = self.get_sr_pr()[1] api_url, request_body = self.get_example_request( gui.session_state, include_all=include_all, @@ -2100,9 +2036,7 @@ def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict): raise InsufficientCredits(self.request.user, sr) def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: - assert ( - self.request and self.request.user - ), "request.user must be set to deduct credits" + assert self.request.user, "request.user must be set to deduct credits" amount = self.get_price_roundoff(state) txn = self.request.user.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") @@ -2119,7 +2053,7 @@ def get_raw_price(self, state: dict) -> float: def get_total_linked_usage_cost_in_credits(self, default=1): """Return the sum of the linked usage costs in gooey credits.""" - sr = self.get_current_sr() + sr = self.get_sr_pr()[0] total = sr.usage_costs.aggregate(total=Sum("dollar_amount"))["total"] if not total: return default @@ -2128,7 +2062,7 @@ def get_total_linked_usage_cost_in_credits(self, default=1): def get_grouped_linked_usage_cost_in_credits(self): """Return the linked usage costs grouped by model name in gooey credits.""" qs = ( - self.get_current_sr() + self.get_sr_pr()[0] .usage_costs.values("pricing__model_name") .annotate(total=Sum("dollar_amount") * settings.ADDON_CREDITS_PER_DOLLAR) ) @@ -2174,7 +2108,7 @@ def get_example_response_body( run_id=run_id, uid=self.request.user and self.request.user.uid, ) - sr = self.get_current_sr() + sr = self.get_sr_pr()[0] output = sr.api_output(extract_model_fields(self.ResponseModel, state)) if as_async: return dict( @@ -2205,17 +2139,13 @@ def is_user_admin(cls, user: AppUser) -> bool: return email and email in settings.ADMIN_EMAILS def is_current_user_admin(self) -> bool: - if not self.request or not self.request.user: - return False - return self.is_user_admin(self.request.user) + return self.request.user and self.is_user_admin(self.request.user) def is_current_user_paying(self) -> bool: - return bool(self.request and self.request.user and self.request.user.is_paying) + return bool(self.request.user and self.request.user.is_paying) def is_current_user_owner(self) -> bool: - return bool( - self.request and self.request.user and self.run_user == self.request.user - ) + return bool(self.request.user and self.run_user == self.request.user) def started_at_text(dt: datetime.datetime): diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 15d89b4d6..ccf864121 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -7,6 +7,7 @@ from django.db import transaction from django.utils.text import slugify from furl import furl +from starlette.requests import Request from app_users.models import AppUser from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform @@ -54,7 +55,7 @@ def integrations_welcome_screen(title: str): gui.caption("Analyze your usage. Update your Saved Run to test changes.") -def general_integration_settings(bi: BotIntegration, current_user: AppUser): +def general_integration_settings(bi: BotIntegration, request: Request): if gui.session_state.get(f"_bi_reset_{bi.id}"): gui.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( BotIntegration._meta.get_field("streaming_enabled").default @@ -101,9 +102,10 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): "📊 View Results", str( furl( - VideoBotsPage.current_app_url( - RecipeTabs.integrations, + VideoBotsPage.app_url( + tab=RecipeTabs.integrations, path_params=dict(integration_id=bi.api_integration_id()), + query_params=dict(request.query_params), ) ) / "analysis/" @@ -119,7 +121,7 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): key=key, internal_state=d, del_key=del_key, - current_user=current_user, + current_user=request.user, ) if not ret: return diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 90fef04ee..f33f453da 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -440,7 +440,7 @@ def _process_and_send_msg( # wait for the celery task to finish get_celery_result_db_safe(result) # get the final state from db - sr = page.run_doc_sr(run_id, uid) + sr = page.get_sr_pr()[0] state = sr.to_dict() bot.recipe_run_state = page.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index c3335434b..af0a00b3b 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -17,11 +17,8 @@ def build_meta_tags( url: str, page: "BasePage", state: dict, - run_id: str, - uid: str, - example_id: str, ) -> list[dict]: - sr, pr = page.get_runs_from_query_params(example_id, run_id, uid) + sr, pr = page.get_sr_pr() metadata = page.workflow.get_or_create_metadata() title = meta_title_for_page( diff --git a/daras_ai_v2/workflow_url_input.py b/daras_ai_v2/workflow_url_input.py index 5c5f7b2e9..20a874e91 100644 --- a/daras_ai_v2/workflow_url_input.py +++ b/daras_ai_v2/workflow_url_input.py @@ -136,7 +136,7 @@ def url_to_runs( assert match, "Not a valid Gooey.AI URL" page_cls = page_slug_map[normalize_slug(match.matched_params["page_slug"])] example_id, run_id, uid = extract_query_params(furl(url).query.params) - sr, pr = page_cls.get_runs_from_query_params( + sr, pr = page_cls.get_sr_pr_from_query_params( example_id or match.matched_params.get("example_id"), run_id, uid ) return page_cls, sr, pr diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 3bbccef25..08275f6f8 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -136,7 +136,7 @@ def render_settings(self): gui.write("---") gui.write("##### 🔎 Document Search Settings") citation_style_selector() - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) query_instructions_widget() gui.write("---") doc_search_advanced_settings() @@ -175,7 +175,7 @@ def run_v2( "search_query": response.final_search_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) # empty search result, abort! diff --git a/recipes/Functions.py b/recipes/Functions.py index 356381343..04448f216 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -59,10 +59,8 @@ def run_v2( request: "FunctionsPage.RequestModel", response: "FunctionsPage.ResponseModel", ) -> typing.Iterator[str | None]: - query_params = gui.get_query_params() - run_id = query_params.get("run_id") - uid = query_params.get("uid") - tag = f"run_id={run_id}&uid={uid}" + sr = self.get_sr_pr()[0] + tag = f"run_id={sr.run_id}&uid={sr.uid}" yield "Running your code..." # this will run functions/executor.js in deno deploy @@ -86,7 +84,7 @@ def render_form_v2(self): ) def get_price_roundoff(self, state: dict) -> float: - if CalledFunction.objects.filter(function_run=self.get_current_sr()).exists(): + if CalledFunction.objects.filter(function_run=self.get_sr_pr()[0]).exists(): return 0 return super().get_price_roundoff(state) diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 6a9eaf999..611483287 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -254,7 +254,7 @@ def run_v2( }, ), is_user_url=False, - current_user=self.request and self.request.user, + current_user=self.request.user, ) # add pretty titles to references for ref in response.references: diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 41e7a33a8..0d7ace339 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -85,7 +85,6 @@ from daras_ai_v2.prompt_vars import render_prompt_vars from daras_ai_v2.pydantic_validation import FieldHttpUrl from daras_ai_v2.query_generator import generate_final_search_query -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.search_ref import ( parse_refs, CitationStyles, @@ -521,7 +520,7 @@ def render_settings(self): citation_style_selector() gui.checkbox("🔗 Shorten Citation URLs", key="use_url_shortener") - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) gui.write("---") @@ -886,7 +885,7 @@ def run_v2( "keyword_query": response.final_keyword_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) if request.use_url_shortener: for reference in response.references: @@ -1061,9 +1060,7 @@ def render_integrations_tab(self): gui.anchor("Get Started", href=self.get_auth_url(), type="primary") return - sr, pr = self.get_runs_from_query_params( - *extract_query_params(gui.get_query_params()) - ) + sr, pr = self.get_sr_pr() # make user the user knows that they are on a saved run not the published run if pr and pr.saved_run_id != sr.id: @@ -1377,7 +1374,7 @@ def render_integrations_settings( slack_specific_settings(bi, run_title) if bi.platform == Platform.TWILIO: twilio_specific_settings(bi) - general_integration_settings(bi, self.request.user) + general_integration_settings(bi, self.request) if bi.platform in [Platform.SLACK, Platform.WHATSAPP, Platform.TWILIO]: gui.newline() diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 648a2894f..3f3bbacb6 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -84,8 +84,9 @@ def show_title_breadcrumb_share( ) gui.breadcrumb_item( "Integrations", - link_to=VideoBotsPage.current_app_url( - RecipeTabs.integrations, + link_to=VideoBotsPage.app_url( + tab=RecipeTabs.integrations, + query_params=dict(self.request.query_params), path_params=dict( integration_id=bi.api_integration_id() ), @@ -152,7 +153,7 @@ def render(self): ) ) - run_url = VideoBotsPage.current_app_url() + run_url = VideoBotsPage.app_url(query_params=dict(self.request.query_params)) if bi.published_run_id: run_title = bi.published_run.title else: diff --git a/routers/api.py b/routers/api.py index 9b795d426..019b77e8e 100644 --- a/routers/api.py +++ b/routers/api.py @@ -258,8 +258,13 @@ def get_run_status( run_id: str, user: AppUser = Depends(api_auth_header), ): - self = page_cls() - sr = self.get_sr_from_query_params(example_id=None, run_id=run_id, uid=user.uid) + # init a new page for every request + self = page_cls( + request=SimpleNamespace( + user=user, query_params=dict(run_id=run_id, uid=user.uid) + ) + ) + sr = self.get_sr_pr()[0] web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid))) ret = { "run_id": run_id, @@ -335,18 +340,17 @@ def submit_api_call( deduct_credits: bool = True, ) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: # init a new page for every request - self = page_cls(request=SimpleNamespace(user=user)) + query_params.setdefault("uid", user.uid) + self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) # get saved state from db - query_params.setdefault("uid", user.uid) - sr = self.get_sr_from_query_params_dict(query_params) + sr = self.get_sr_pr()[0] state = self.load_state_from_sr(sr) # load request data state.update(request_body) # set streamlit session state gui.set_session_state(state) - gui.set_query_params(query_params) # create a new run try: @@ -369,7 +373,7 @@ def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: / page.endpoint.replace("v2", "v3") / "status/" ) - sr = page.run_doc_sr(run_id, uid) + sr = page.get_sr_pr()[0] return dict( run_id=run_id, web_url=web_url, @@ -388,7 +392,7 @@ def build_sync_api_response( web_url = page.app_url(run_id=run_id, uid=uid) # wait for the result get_celery_result_db_safe(result) - sr = page.run_doc_sr(run_id, uid) + sr = page.get_sr_pr()[0] if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) diff --git a/routers/bots_api.py b/routers/bots_api.py index 780a7b918..f68d3d2de 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -302,7 +302,7 @@ def runner(self): msg_handler(self) # raise ValueError("Stream ended") if self.run_id and self.uid: - sr = self.page_cls.run_doc_sr(run_id=self.run_id, uid=self.uid) + sr = self.page_cls.get_saved_run(run_id=self.run_id, uid=self.uid) state = sr.to_dict() self.queue.put( FinalResponse( diff --git a/routers/root.py b/routers/root.py index 93d889b36..84447bd1e 100644 --- a/routers/root.py +++ b/routers/root.py @@ -39,7 +39,6 @@ from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.profiles import user_profile_page, get_meta_tags_for_profile -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.settings import templates from handles.models import Handle from routers.custom_api_router import CustomAPIRouter @@ -654,12 +653,13 @@ def render_recipe_page( return RedirectResponse(str(new_url.set(origin=None)), status_code=301) # this is because the code still expects example_id to be in the query params - gui.set_query_params(dict(request.query_params) | dict(example_id=example_id)) - _, run_id, uid = extract_query_params(request.query_params) + request._query_params = dict(request.query_params) | dict(example_id=example_id) + + page = page_cls(tab=tab, request=request) + sr = page.get_sr_pr()[0] + page.run_user = get_run_user(request, sr.uid) - page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid)) if not gui.session_state: - sr = page.get_sr_from_query_params(example_id, run_id, uid) gui.session_state.update(page.load_state_from_sr(sr)) with page_wrapper(request): @@ -667,12 +667,7 @@ def render_recipe_page( return dict( meta=build_meta_tags( - url=get_og_url_path(request), - page=page, - state=gui.session_state, - run_id=run_id, - uid=uid, - example_id=example_id, + url=get_og_url_path(request), page=page, state=gui.session_state ), ) @@ -683,7 +678,7 @@ def get_og_url_path(request) -> str: ) -def get_run_user(request, uid) -> AppUser | None: +def get_run_user(request: Request, uid: str) -> AppUser | None: if not uid: return if request.user and request.user.uid == uid: diff --git a/tests/test_apis.py b/tests/test_apis.py index fa897eb83..fe70b9079 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -21,7 +21,7 @@ def test_apis_sync(mock_celery_tasks, force_authentication, threadpool_subtest): def _test_api_sync(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_published_run().saved_run.state r = client.post( f"/v2/{page_cls.slug_versions[0]}/", json=page_cls.get_example_request(state)[1], @@ -38,7 +38,7 @@ def test_apis_async(mock_celery_tasks, force_authentication, threadpool_subtest) def _test_api_async(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_published_run().saved_run.state r = client.post( f"/v3/{page_cls.slug_versions[0]}/async/", diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 6ac6e0591..73e05e4a9 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import gooey_gui as gui import pytest from starlette.testclient import TestClient @@ -46,8 +48,12 @@ def test_copilot_get_raw_price_round_up(): unit_quantity=model_pricing.unit_quantity, dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) - copilot_page = VideoBotsPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + copilot_page = VideoBotsPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ), + ) assert ( copilot_page.get_price_roundoff(state=state) == 210 + copilot_page.PROFIT_CREDITS @@ -107,8 +113,12 @@ def test_multiple_llm_sums_usage_cost(): dollar_amount=model_pricing2.unit_cost * 1 / model_pricing2.unit_quantity, ) - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -152,8 +162,12 @@ def test_workflowmetadata_2x_multiplier(): metadata.price_multiplier = 2 metadata.save() - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2 ) diff --git a/url_shortener/models.py b/url_shortener/models.py index 21b0864d7..5597fe0c5 100644 --- a/url_shortener/models.py +++ b/url_shortener/models.py @@ -6,10 +6,9 @@ from app_users.models import AppUser from bots.custom_fields import CustomURLField from bots.models import Workflow, SavedRun +from celeryapp.tasks import get_current_saved_run from daras_ai.image_input import truncate_filename from daras_ai_v2 import settings -from daras_ai_v2.query_params_util import extract_query_params -import gooey_gui as gui class ShortenedURLQuerySet(models.QuerySet): @@ -17,14 +16,8 @@ def get_or_create_for_workflow( self, *, user: AppUser, workflow: Workflow, **kwargs ) -> tuple["ShortenedURL", bool]: surl, created = self.filter_first_or_create(user=user, **kwargs) - _, run_id, uid = extract_query_params(gui.get_query_params()) - surl.saved_runs.add( - SavedRun.objects.get_or_create( - workflow=workflow, - run_id=run_id, - uid=uid, - )[0], - ) + sr = get_current_saved_run() + surl.saved_runs.add(sr) return surl, created def filter_first_or_create(self, defaults=None, **kwargs): diff --git a/usage_costs/cost_utils.py b/usage_costs/cost_utils.py index 6596c0d22..8ab27e0dd 100644 --- a/usage_costs/cost_utils.py +++ b/usage_costs/cost_utils.py @@ -1,19 +1,16 @@ from loguru import logger -from daras_ai_v2.query_params_util import extract_query_params +from celeryapp.tasks import get_current_saved_run from usage_costs.models import ( UsageCost, ModelSku, ModelPricing, ) -import gooey_gui as gui def record_cost_auto(model: str, sku: ModelSku, quantity: int): - from bots.models import SavedRun - - _, run_id, uid = extract_query_params(gui.get_query_params()) - if not run_id or not uid: + sr = get_current_saved_run() + if not sr: return try: @@ -22,10 +19,8 @@ def record_cost_auto(model: str, sku: ModelSku, quantity: int): logger.warning(f"Cant find pricing for {model=} {sku=}: {e=}") return - saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) - UsageCost.objects.create( - saved_run=saved_run, + saved_run=sr, pricing=pricing, quantity=quantity, unit_cost=pricing.unit_cost,