diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index a06dbeddb..d3a52175d 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -1,4 +1,5 @@
import datetime
+import hashlib
import html
import inspect
import json
@@ -139,8 +140,7 @@ class RequestModel(BaseModel):
functions: list[RecipeFunction] | None = Field(
title="🧩 Developer Tools and Functions",
)
- variables: dict[str, typing.Any] = Field(
- None,
+ variables: dict[str, typing.Any] | None = Field(
title="⌥ Variables",
description="Variables to be used as Jinja prompt templates and in functions as arguments",
)
@@ -305,12 +305,8 @@ def sentry_event_set_user(self, event, hint):
return event
def refresh_state(self):
- example_id, run_id, uid = extract_query_params(gui.get_query_params())
- if not run_id:
- sr = self.get_sr_from_query_params(example_id, run_id, uid)
- run_id, uid = sr.run_id, sr.uid
-
- channel = self.realtime_channel_name(run_id, uid)
+ sr = self.get_current_sr()
+ channel = self.realtime_channel_name(sr.run_id, sr.uid)
output = gui.realtime_pull([channel])[0]
if output:
gui.session_state.update(output)
@@ -330,7 +326,7 @@ def render(self):
self.render_report_form()
return
- self._render_header()
+ header_placeholder = gui.div()
gui.newline()
with gui.nav_tabs():
@@ -341,14 +337,19 @@ def render(self):
with gui.nav_tab_content():
self.render_selected_tab()
+ with header_placeholder:
+ self._render_header()
+
def _render_header(self):
current_run = self.get_current_sr()
published_run = self.get_current_published_run()
- is_example = published_run and published_run.saved_run == current_run
+ is_example = published_run.saved_run == current_run
is_root_example = is_example and published_run.is_root()
tbreadcrumbs = get_title_breadcrumbs(
self, current_run, published_run, tab=self.tab
)
+ can_edit = self.can_user_edit_run(current_run, published_run)
+ request_changed = self._has_request_changed()
with gui.div(className="d-flex justify-content-between mt-4"):
with gui.div(className="d-lg-flex d-block align-items-center"):
@@ -365,7 +366,6 @@ def _render_header(self):
)
if is_example:
- assert published_run
author = published_run.created_by
else:
author = self.run_user or current_run.get_creator()
@@ -373,37 +373,27 @@ def _render_header(self):
self.render_author(author)
with gui.div(className="d-flex align-items-center"):
- can_user_edit_run = self.can_user_edit_run(current_run, published_run)
- has_unpublished_changes = (
- published_run
- and published_run.saved_run != current_run
- and self.request
- and self.request.user
- )
-
- if (
- can_user_edit_run and has_unpublished_changes
- ) or self._has_current_run_changed(current_run):
+ if request_changed or (can_edit and not is_example):
self._render_unpublished_changes_indicator()
with gui.div(className="d-flex align-items-start right-action-icons"):
gui.html(
"""
-
- """
+
+ """
)
- if published_run and can_user_edit_run:
- self._render_published_run_buttons(
+ show_save_buttons = request_changed or can_edit
+ if show_save_buttons:
+ self._render_published_run_save_buttons(
current_run=current_run,
published_run=published_run,
)
-
- self._render_social_buttons(show_button_text=not can_user_edit_run)
+ self._render_social_buttons(show_button_text=not show_save_buttons)
if tbreadcrumbs.has_breadcrumbs() or self.run_user:
# only render title here if the above row was not empty
@@ -457,11 +447,10 @@ def _render_unpublished_changes_indicator(self):
gui.html("Unpublished changes")
def _render_social_buttons(self, show_button_text: bool = False):
- button_text = (
- ' Copy Link'
- if show_button_text
- else ""
- )
+ if show_button_text:
+ button_text = ' Copy Link'
+ else:
+ button_text = ""
copy_to_clipboard_button(
f'{button_text}',
@@ -470,12 +459,11 @@ def _render_social_buttons(self, show_button_text: bool = False):
className="mb-0 ms-lg-2",
)
- def _render_published_run_buttons(
+ def _render_published_run_save_buttons(
self,
*,
current_run: SavedRun,
published_run: PublishedRun,
- redirect_to: str | None = None,
):
is_update_mode = (
self.is_current_user_admin()
@@ -533,7 +521,6 @@ def _render_published_run_buttons(
published_run=published_run,
modal=publish_modal,
is_update_mode=is_update_mode,
- redirect_to=redirect_to,
)
def _render_publish_modal(
@@ -543,10 +530,7 @@ def _render_publish_modal(
published_run: PublishedRun,
modal: gui.Modal,
is_update_mode: bool = False,
- redirect_to: str | None = None,
):
- is_example = published_run.saved_run == current_run
-
if published_run.is_root() and self.is_current_user_admin():
with gui.div(className="text-danger"):
gui.write(
@@ -633,10 +617,10 @@ def _render_publish_modal(
gui.error(str(e))
return
- if self._has_current_run_changed(current_run):
- sr = self._on_submit()
- if sr:
- current_run = sr
+ if self._has_request_changed():
+ current_run = self.on_submit()
+ if not current_run:
+ modal.close()
if is_update_mode:
updates = dict(
@@ -660,13 +644,7 @@ def _render_publish_modal(
notes=published_run_notes.strip(),
visibility=published_run_visibility,
)
-
- if redirect_to:
- raise gui.RedirectException(redirect_to)
- elif is_example:
- modal.close() # implicit gui.rerun to reload the updated run
- else:
- raise gui.RedirectException(published_run.get_app_url())
+ raise gui.RedirectException(published_run.get_app_url())
def _validate_published_run_title(self, title: str):
if slugify(title) in settings.DISALLOWED_TITLE_SLUGS:
@@ -696,15 +674,23 @@ def _has_published_run_changed(
or published_run.saved_run != saved_run
)
- def _has_current_run_changed(self, sr: SavedRun) -> bool:
- """are there unsaved changes that haven't been run?"""
+ def _has_request_changed(self) -> bool:
+ if gui.session_state.get("--has-request-changed"):
+ return True
+
try:
- extracted_state = self.RequestModel.parse_obj(gui.session_state)
- extracted_sr = self.RequestModel.parse_obj(sr.to_dict())
- return extracted_sr != extracted_state
- except ValidationError as e:
- # don't want page to be inaccessible if ever validation fails - log and continue
- sentry_sdk.capture_exception(e)
+ curr_req = self.RequestModel.parse_obj(gui.session_state)
+ except ValidationError:
+ # if the request model fails to parse, the request has likely changed
+ return True
+
+ curr_hash = hashlib.md5(curr_req.json(sort_keys=True).encode()).hexdigest()
+ prev_hash = gui.session_state.setdefault("--prev-request-hash", curr_hash)
+
+ if curr_hash != prev_hash:
+ gui.session_state["--has-request-changed"] = True # cache it for next time
+ return True
+ else:
return False
def _render_options_modal(
@@ -1542,7 +1528,7 @@ def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = Fals
submitted = True
if submitted or self.should_submit_after_login():
- self.on_submit()
+ self.submit_and_redirect()
run_state = self.get_run_state(gui.session_state)
match run_state:
@@ -1604,12 +1590,13 @@ def render_extra_waiting_output(self):
def estimate_run_duration(self) -> int | None:
pass
- def on_submit(self):
- sr = self._on_submit()
- if sr:
- raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
+ def submit_and_redirect(self):
+ sr = self.on_submit()
+ if not sr:
+ return
+ raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
- def _on_submit(self):
+ def on_submit(self):
try:
sr = self.create_new_run(enable_rate_limits=True)
except ValidationError as e:
@@ -1620,9 +1607,7 @@ def _on_submit(self):
gui.session_state[StateKeys.run_status] = None
gui.session_state[StateKeys.error_msg] = e.detail.get("error", "")
return
-
self.call_runner_task(sr)
-
return sr
def should_submit_after_login(self) -> bool:
diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py
index c315da996..b81932533 100644
--- a/daras_ai_v2/prompt_vars.py
+++ b/daras_ai_v2/prompt_vars.py
@@ -55,7 +55,9 @@ def render_title_desc():
- set(gui.session_state.keys()) # dont show other session state variables
)
- gui.session_state[key] = new_vars = {}
+ new_vars = {}
+ if all_var_names:
+ gui.session_state[key] = new_vars
title_shown = False
for name in sorted(all_var_names):
var_key = f"--{key}:{name}"
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index edc3cadf3..41e7a33a8 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -662,7 +662,7 @@ def on_send(
gui.session_state["input_images"] = new_input_images or None
gui.session_state["input_documents"] = new_input_documents or None
- self.on_submit()
+ self.submit_and_redirect()
def render_steps(self):
if gui.session_state.get("tts_provider"):