diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 2e30e4379..29e480602 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -41,6 +41,7 @@ def runner_task( run_id: str, uid: str, channel: str, + unsaved_state: dict[str, typing.Any] = None, ) -> int: start_time = time() error_msg = None @@ -84,7 +85,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False page = page_cls(request=SimpleNamespace(user=user)) page.setup_sentry() sr = page.run_doc_sr(run_id, uid) - st.set_session_state(sr.to_dict()) + st.set_session_state(sr.to_dict() | (unsaved_state or {})) set_query_params(dict(run_id=run_id, uid=uid)) try: diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 5be61de83..472b7ffae 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1699,6 +1699,7 @@ def call_runner_task(self, sr: SavedRun): run_id=sr.run_id, uid=sr.uid, channel=self.realtime_channel_name(sr.run_id, sr.uid), + unsaved_state=self._unsaved_state(), ) | post_runner_tasks.s() ) @@ -1782,7 +1783,7 @@ def load_state_defaults(cls, state: dict): state.setdefault(k, v) return state - def fields_to_save(self) -> [str]: + def fields_to_save(self) -> list[str]: # only save the fields in request/response return [ field_name @@ -1794,6 +1795,18 @@ def fields_to_save(self) -> [str]: StateKeys.run_time, ] + def _unsaved_state(self) -> dict[str, typing.Any]: + result = {} + for field in self.fields_not_to_save(): + try: + result[field] = st.session_state[field] + except KeyError: + pass + return result + + def fields_not_to_save(self) -> list[str]: + return [] + def _examples_tab(self): allow_hide = self.is_current_user_admin() @@ -2059,7 +2072,7 @@ def run_as_api_tab(self): api_example_generator( api_url=api_url, - request_body=request_body, + request_body=request_body | self._unsaved_state(), as_form_data=as_form_data, as_async=as_async, ) diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index 2fb29eba0..44e2d0976 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -1,17 +1,20 @@ -from enum import Enum import typing +from enum import Enum import requests from furl import furl -from daras_ai_v2.azure_asr import azure_auth_header import gooey_ui as st from daras_ai_v2 import settings +from daras_ai_v2.azure_asr import azure_auth_header from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.redis_cache import redis_cache_decorator +if typing.TYPE_CHECKING: + from daras_ai_v2.base import BasePage + SESSION_ELEVENLABS_API_KEY = "__user__elevenlabs_api_key" UBERDUCK_VOICES = { @@ -329,20 +332,8 @@ def uberduck_settings(): ) -def elevenlabs_selector(page): - if not st.session_state.get("elevenlabs_api_key"): - st.session_state["elevenlabs_api_key"] = page.request.session.get( - SESSION_ELEVENLABS_API_KEY - ) - - # for backwards compat - if old_voice_name := st.session_state.pop("elevenlabs_voice_name", None): - try: - st.session_state["elevenlabs_voice_id"] = OLD_ELEVEN_LABS_VOICES[ - old_voice_name - ] - except KeyError: - pass +def elevenlabs_selector(page: "BasePage"): + elevenlabs_init_state(page) elevenlabs_use_custom_key = st.checkbox( "Use custom API key + Voice ID", @@ -406,6 +397,21 @@ def elevenlabs_selector(page): ) +def elevenlabs_init_state(page: "BasePage"): + if not st.session_state.get("elevenlabs_api_key"): + st.session_state["elevenlabs_api_key"] = page.request.session.get( + SESSION_ELEVENLABS_API_KEY + ) + # for backwards compat + if old_voice_name := st.session_state.pop("elevenlabs_voice_name", None): + try: + st.session_state["elevenlabs_voice_id"] = OLD_ELEVEN_LABS_VOICES[ + old_voice_name + ] + except KeyError: + pass + + def elevenlabs_settings(): col1, col2 = st.columns(2) with col1: diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index f1e5449ce..48a4baaf9 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -24,6 +24,7 @@ OpenAI_TTS_Models, OpenAI_TTS_Voices, OLD_ELEVEN_LABS_VOICES, + elevenlabs_init_state, ) DEFAULT_TTS_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a73181ce-9457-11ee-8edd-02420a0001c7/Voice%20generators.jpg.png" @@ -97,6 +98,21 @@ def fallback_preivew_image(self) -> str | None: def get_example_preferred_fields(cls, state: dict) -> list[str]: return ["tts_provider"] + def fields_not_to_save(self): + return ["elevenlabs_api_key"] + + def fields_to_save(self): + fields = super().fields_to_save() + try: + fields.remove("elevenlabs_api_key") + except ValueError: + pass + return fields + + def run_as_api_tab(self): + elevenlabs_init_state(self) + super().run_as_api_tab() + def preview_description(self, state: dict) -> str: return "Input your text, pick a voice & a Text-to-Speech AI engine to create audio. Compare the best voice generators from Google, UberDuck.ai & more to add automated voices to your podcast, YouTube videos, website, or app." @@ -123,12 +139,6 @@ def render_form_v2(self): ) text_to_speech_provider_selector(self) - def fields_to_save(self): - fields = super().fields_to_save() - if "elevenlabs_api_key" in fields: - fields.remove("elevenlabs_api_key") - return fields - def validate_form_v2(self): assert st.session_state.get("text_prompt"), "Text input cannot be empty" assert st.session_state.get("tts_provider"), "Please select a TTS provider" diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 3c9f5772e..b3218d9ed 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -90,6 +90,7 @@ TextToSpeechProviders, text_to_speech_settings, text_to_speech_provider_selector, + elevenlabs_init_state, ) from daras_ai_v2.vector_search import DocSearchRequest from functions.recipe_functions import LLMTools @@ -545,12 +546,21 @@ def render_settings(self): key="tools", ) + def fields_not_to_save(self): + return ["elevenlabs_api_key"] + def fields_to_save(self) -> [str]: - fields = super().fields_to_save() + ["landbot_url"] - if "elevenlabs_api_key" in fields: + fields = super().fields_to_save() + try: fields.remove("elevenlabs_api_key") + except ValueError: + pass return fields + def run_as_api_tab(self): + elevenlabs_init_state(self) + super().run_as_api_tab() + def render_example(self, state: dict): input_prompt = state.get("input_prompt") if input_prompt: