Skip to content

Commit

Permalink
Refactor rendering in base.py and add countdown timer support
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Nov 14, 2023
1 parent eaec380 commit 0ecbc0e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 21 deletions.
4 changes: 2 additions & 2 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def to_dict(self) -> dict:

ret = self.state.copy()
if self.updated_at:
ret[StateKeys.updated_at] = self.updated_at
ret[StateKeys.updated_at] = self.updated_at.isoformat()
if self.created_at:
ret[StateKeys.created_at] = self.created_at
ret[StateKeys.created_at] = self.created_at.isoformat()
if self.error_msg:
ret[StateKeys.error_msg] = self.error_msg
if self.run_time:
Expand Down
81 changes: 62 additions & 19 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ def api_url(self, example_id=None, run_id=None, uid=None) -> furl:
def endpoint(self) -> str:
return f"/v2/{self.slug_versions[0]}/"

def get_tab_url(self, tab: str) -> str:
example_id, run_id, uid = extract_query_params(gooey_get_query_params())
return self.app_url(
example_id=example_id,
run_id=run_id,
uid=uid,
tab_name=MenuTabs.paths[tab],
)

def render(self):
with sentry_sdk.configure_scope() as scope:
scope.set_extra("base_url", self.app_url())
Expand Down Expand Up @@ -191,10 +200,7 @@ def render(self):
with st.nav_tabs():
tab_names = self.get_tabs()
for name in tab_names:
url = self.app_url(
*extract_query_params(gooey_get_query_params()),
tab_name=MenuTabs.paths[name],
)
url = self.get_tab_url(name)
with st.nav_item(url, active=name == selected_tab):
st.html(name)
with st.nav_tab_content():
Expand Down Expand Up @@ -634,6 +640,19 @@ def _render_input_col(self):
submitted = self.render_submit_button()
return submitted

def get_run_state(
self,
) -> typing.Literal["success", "error", "waiting", "recipe_root"]:
if st.session_state.get(StateKeys.run_status):
return "waiting"
elif st.session_state.get(StateKeys.error_msg):
return "error"
elif st.session_state.get(StateKeys.run_time):
return "success"
else:
# when user is at a recipe root, and not running anything
return "recipe_root"

def _render_output_col(self, submitted: bool):
assert inspect.isgeneratorfunction(self.run)

Expand All @@ -647,27 +666,40 @@ def _render_output_col(self, submitted: bool):

self._render_before_output()

run_status = st.session_state.get(StateKeys.run_status)
if run_status:
st.caption("Your changes are saved in the above URL. Save it for later!")
html_spinner(run_status)
else:
err_msg = st.session_state.get(StateKeys.error_msg)
run_time = st.session_state.get(StateKeys.run_time, 0)

# render errors
if err_msg is not None:
st.error(err_msg)
# render run time
elif run_time:
st.success(f"Success! Run Time: `{run_time:.2f}` seconds.")
run_state = self.get_run_state()
match run_state:
case "success":
self._render_success_output()
case "error":
self._render_error_output()
case "waiting":
self._render_waiting_output()
case "recipe_root":
pass

# render outputs
self.render_output()

if not run_status:
if run_state != "waiting":
self._render_after_output()

def _render_success_output(self):
run_time = st.session_state.get(StateKeys.run_time, 0)
st.success(f"Success! Run Time: `{run_time:.2f}` seconds.")

def _render_error_output(self):
err_msg = st.session_state.get(StateKeys.error_msg)
st.error(err_msg)

def _render_waiting_output(self):
run_status = st.session_state.get(StateKeys.run_status)
st.caption("Your changes are saved in the above URL. Save it for later!")
html_spinner(run_status)
self.render_extra_waiting_output()

def render_extra_waiting_output(self):
pass

def on_submit(self):
example_id, run_id, uid = self.create_new_run()
if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits():
Expand Down Expand Up @@ -1150,6 +1182,17 @@ def is_current_user_admin(self) -> bool:
def is_current_user_paying(self) -> bool:
return bool(self.request and self.request.user and self.request.user.is_paying)

def is_current_user_owner(self) -> bool:
"""
Did the current user create this run?
"""
return bool(
self.request
and self.request.user
and self.run_user
and self.request.user.uid == self.run_user.uid
)


def get_example_request_body(
request_model: typing.Type[BaseModel],
Expand Down
12 changes: 12 additions & 0 deletions gooey_ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import textwrap
import typing
from datetime import datetime, timezone

import numpy as np

Expand Down Expand Up @@ -32,6 +33,17 @@ def dummy(*args, **kwargs):
dataframe = dummy


def countdown_timer(
end_time: datetime,
delay_text: str,
) -> state.NestingCtx:
return _node(
"countdown-timer",
endTime=end_time.astimezone(timezone.utc).isoformat(),
delayText=delay_text,
)


def nav_tabs():
return _node("nav-tabs")

Expand Down
22 changes: 22 additions & 0 deletions recipes/DeforumSD.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing
import uuid
from datetime import datetime, timedelta

from django.db.models import TextChoices
from pydantic import BaseModel
Expand All @@ -12,6 +13,7 @@
from daras_ai_v2.gpu_server import call_celery_task_outfile
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.safety_checker import safety_checker
from daras_ai_v2.tabs_widget import MenuTabs


class AnimationModels(TextChoices):
Expand All @@ -27,6 +29,7 @@ class _AnimationPrompt(TypedDict):
AnimationPrompts = list[_AnimationPrompt]

CREDITS_PER_FRAME = 1.5
MODEL_ESTIMATED_TIME_PER_FRAME = 2.4 # seconds


def input_prompt_to_animation_prompts(input_prompt: str):
Expand Down Expand Up @@ -417,6 +420,25 @@ def render_output(self):
st.write("Output Video")
st.video(output_video, autoplay=True)

def render_extra_waiting_output(self):
if created_at := st.session_state.get("created_at"):
start_time = datetime.fromisoformat(created_at)
with st.countdown_timer(
end_time=start_time + timedelta(seconds=self.estimate_run_duration()),
delay_text="Sorry for the wait. Your run is taking longer than we expected.",
):
if self.is_current_user_owner() and self.request.user.email:
st.write(
f"""We'll email **{self.request.user.email}** when your workflow is done."""
)
st.write(
f"""In the meantime, check out [🚀 Examples]({self.get_tab_url(MenuTabs.examples)}) for inspiration."""
)

def estimate_run_duration(self):
# in seconds
return st.session_state.get("max_frames", 100) * MODEL_ESTIMATED_TIME_PER_FRAME

def render_example(self, state: dict):
display = self.preview_input(state)
st.markdown("```lua\n" + display + "\n```")
Expand Down

0 comments on commit 0ecbc0e

Please sign in to comment.