Skip to content

Commit

Permalink
Add countdown timer component
Browse files Browse the repository at this point in the history
Additionally, also refactor base.py for readability and clear run states (waiting / success / error / recipe_root). This PR also adds a render_extra_waiting_output method on base.py, that can be updated separately for each recipe.
  • Loading branch information
nikochiko committed Nov 14, 2023
1 parent eaec380 commit 060805c
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 21 deletions.
4 changes: 2 additions & 2 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def to_dict(self) -> dict:

ret = self.state.copy()
if self.updated_at:
ret[StateKeys.updated_at] = self.updated_at
ret[StateKeys.updated_at] = self.updated_at.isoformat()
if self.created_at:
ret[StateKeys.created_at] = self.created_at
ret[StateKeys.created_at] = self.created_at.isoformat()
if self.error_msg:
ret[StateKeys.error_msg] = self.error_msg
if self.run_time:
Expand Down
81 changes: 62 additions & 19 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ def api_url(self, example_id=None, run_id=None, uid=None) -> furl:
def endpoint(self) -> str:
return f"/v2/{self.slug_versions[0]}/"

def get_tab_url(self, tab: str) -> str:
example_id, run_id, uid = extract_query_params(gooey_get_query_params())
return self.app_url(
example_id=example_id,
run_id=run_id,
uid=uid,
tab_name=MenuTabs.paths[tab],
)

def render(self):
with sentry_sdk.configure_scope() as scope:
scope.set_extra("base_url", self.app_url())
Expand Down Expand Up @@ -191,10 +200,7 @@ def render(self):
with st.nav_tabs():
tab_names = self.get_tabs()
for name in tab_names:
url = self.app_url(
*extract_query_params(gooey_get_query_params()),
tab_name=MenuTabs.paths[name],
)
url = self.get_tab_url(name)
with st.nav_item(url, active=name == selected_tab):
st.html(name)
with st.nav_tab_content():
Expand Down Expand Up @@ -634,6 +640,19 @@ def _render_input_col(self):
submitted = self.render_submit_button()
return submitted

def get_run_state(
self,
) -> typing.Literal["success", "error", "waiting", "recipe_root"]:
if st.session_state.get(StateKeys.run_status):
return "waiting"
elif st.session_state.get(StateKeys.error_msg):
return "error"
elif st.session_state.get(StateKeys.run_time):
return "success"
else:
# when user is at a recipe root, and not running anything
return "recipe_root"

def _render_output_col(self, submitted: bool):
assert inspect.isgeneratorfunction(self.run)

Expand All @@ -647,27 +666,40 @@ def _render_output_col(self, submitted: bool):

self._render_before_output()

run_status = st.session_state.get(StateKeys.run_status)
if run_status:
st.caption("Your changes are saved in the above URL. Save it for later!")
html_spinner(run_status)
else:
err_msg = st.session_state.get(StateKeys.error_msg)
run_time = st.session_state.get(StateKeys.run_time, 0)

# render errors
if err_msg is not None:
st.error(err_msg)
# render run time
elif run_time:
st.success(f"Success! Run Time: `{run_time:.2f}` seconds.")
run_state = self.get_run_state()
match run_state:
case "success":
self._render_success_output()
case "error":
self._render_error_output()
case "waiting":
self._render_waiting_output()
case "recipe_root":
pass

# render outputs
self.render_output()

if not run_status:
if run_state != "waiting":
self._render_after_output()

def _render_success_output(self):
run_time = st.session_state.get(StateKeys.run_time, 0)
st.success(f"Success! Run Time: `{run_time:.2f}` seconds.")

def _render_error_output(self):
err_msg = st.session_state.get(StateKeys.error_msg)
st.error(err_msg)

def _render_waiting_output(self):
run_status = st.session_state.get(StateKeys.run_status)
st.caption("Your changes are saved in the above URL. Save it for later!")
html_spinner(run_status)
self.render_extra_waiting_output()

def render_extra_waiting_output(self):
pass

def on_submit(self):
example_id, run_id, uid = self.create_new_run()
if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits():
Expand Down Expand Up @@ -1150,6 +1182,17 @@ def is_current_user_admin(self) -> bool:
def is_current_user_paying(self) -> bool:
return bool(self.request and self.request.user and self.request.user.is_paying)

def is_current_user_owner(self) -> bool:
"""
Did the current user create this run?
"""
return bool(
self.request
and self.request.user
and self.run_user
and self.request.user.uid == self.run_user.uid
)


def get_example_request_body(
request_model: typing.Type[BaseModel],
Expand Down
12 changes: 12 additions & 0 deletions gooey_ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import textwrap
import typing
from datetime import datetime, timezone

import numpy as np

Expand Down Expand Up @@ -32,6 +33,17 @@ def dummy(*args, **kwargs):
dataframe = dummy


def countdown_timer(
end_time: datetime,
delay_text: str,
) -> state.NestingCtx:
return _node(
"countdown-timer",
endTime=end_time.astimezone(timezone.utc).isoformat(),
delayText=delay_text,
)


def nav_tabs():
return _node("nav-tabs")

Expand Down
22 changes: 22 additions & 0 deletions recipes/DeforumSD.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing
import uuid
from datetime import datetime, timedelta

from django.db.models import TextChoices
from pydantic import BaseModel
Expand All @@ -12,6 +13,7 @@
from daras_ai_v2.gpu_server import call_celery_task_outfile
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.safety_checker import safety_checker
from daras_ai_v2.tabs_widget import MenuTabs


class AnimationModels(TextChoices):
Expand All @@ -27,6 +29,7 @@ class _AnimationPrompt(TypedDict):
AnimationPrompts = list[_AnimationPrompt]

CREDITS_PER_FRAME = 1.5
MODEL_ESTIMATED_TIME_PER_FRAME = 2.4 # seconds


def input_prompt_to_animation_prompts(input_prompt: str):
Expand Down Expand Up @@ -417,6 +420,25 @@ def render_output(self):
st.write("Output Video")
st.video(output_video, autoplay=True)

def render_extra_waiting_output(self):
if created_at := st.session_state.get("created_at"):
start_time = datetime.fromisoformat(created_at)
with st.countdown_timer(
end_time=start_time + timedelta(seconds=self.estimate_run_duration()),
delay_text="Sorry for the wait. Your run is taking longer than we expected.",
):
if self.is_current_user_owner() and self.request.user.email:
st.write(
f"""We'll email **{self.request.user.email}** when your workflow is done."""
)
st.write(
f"""In the meantime, check out [🚀 Examples]({self.get_tab_url(MenuTabs.examples)}) for inspiration."""
)

def estimate_run_duration(self):
# in seconds
return st.session_state.get("max_frames", 100) * MODEL_ESTIMATED_TIME_PER_FRAME

def render_example(self, state: dict):
display = self.preview_input(state)
st.markdown("```lua\n" + display + "\n```")
Expand Down

0 comments on commit 060805c

Please sign in to comment.