diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index f8899b172..4c8af2d8f 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -91,6 +91,7 @@ class BasePage: slug_versions: list[str] sane_defaults: dict = {} + RequestModel: typing.Type[BaseModel] ResponseModel: typing.Type[BaseModel] @@ -154,7 +155,6 @@ def render(self): ) example_id, run_id, uid = extract_query_params(gooey_get_query_params()) - if st.session_state.get(StateKeys.run_status): channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" output = realtime_pull([channel])[0] @@ -1150,6 +1150,11 @@ def is_current_user_admin(self) -> bool: def is_current_user_paying(self) -> bool: return bool(self.request and self.request.user and self.request.user.is_paying) + def is_current_user_owner(self) -> bool: + return bool( + self.request and self.request.user and self.run_user == self.request.user + ) + def get_example_request_body( request_model: typing.Type[BaseModel], diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index d12abcf44..6d75b7396 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -387,8 +387,6 @@ def _run_chat_model( case LLMApis.openai: from openai import OpenAI - print([len(messages), max_tokens, num_outputs, messages]) - client = OpenAI() r = client.chat.completions.create( model=engine, diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index 137cdb3c9..110edc279 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -1,11 +1,14 @@ from enum import Enum -import gooey_ui as st +import requests from google.cloud import texttospeech +import gooey_ui as st from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.redis_cache import redis_cache_decorator +SESSION_ELEVENLABS_API_KEY = "__user__elevenlabs_api_key" + UBERDUCK_VOICES = { "Aiden Botha": "b01cf18d-0f10-46dd-adc6-562b599fdae4", "Angus": "7d29a280-8a3e-4c4b-9df4-cbe77e8f4a63", @@ -23,7 +26,7 @@ class TextToSpeechProviders(Enum): GOOGLE_TTS = "Google Cloud Text-to-Speech" - ELEVEN_LABS = "Eleven Labs (Premium)" + ELEVEN_LABS = "Eleven Labs" UBERDUCK = "uberduck.ai" BARK = "Bark (suno-ai)" @@ -135,7 +138,7 @@ class TextToSpeechProviders(Enum): } -def text_to_speech_settings(page=None): +def text_to_speech_settings(page): st.write( """ ##### 🗣️ Voice Settings @@ -227,25 +230,86 @@ def text_to_speech_settings(page=None): case TextToSpeechProviders.ELEVEN_LABS.name: with col2: - if not ( - page - and (page.is_current_user_paying() or page.is_current_user_admin()) - ): - st.caption( + if not st.session_state.get("elevenlabs_api_key"): + st.session_state["elevenlabs_api_key"] = page.request.session.get( + SESSION_ELEVENLABS_API_KEY + ) + + elevenlabs_use_custom_key = st.checkbox( + "Use custom API key + Voice ID", + value=bool(st.session_state.get("elevenlabs_api_key")), + ) + if elevenlabs_use_custom_key: + st.session_state["elevenlabs_voice_name"] = None + elevenlabs_api_key = st.text_input( """ - Note: Please purchase Gooey.AI credits to use ElevenLabs voices - here. + ###### Your ElevenLabs API key + *Read this + to know how to obtain an API key from + ElevenLabs.* + """, + key="elevenlabs_api_key", + ) + + selected_voice_id = st.session_state.get("elevenlabs_voice_id") + elevenlabs_voices = ( + {selected_voice_id: selected_voice_id} + if selected_voice_id + else {} + ) + + if elevenlabs_api_key: + try: + elevenlabs_voices = fetch_elevenlabs_voices( + elevenlabs_api_key + ) + except requests.exceptions.HTTPError as e: + st.error( + f"Invalid ElevenLabs API key. Failed to fetch voices: {e}" + ) + + st.selectbox( + """ + ###### Voice ID (ElevenLabs) + """, + key="elevenlabs_voice_id", + options=elevenlabs_voices.keys(), + format_func=elevenlabs_voices.__getitem__, + ) + else: + st.session_state["elevenlabs_api_key"] = None + st.session_state["elevenlabs_voice_id"] = None + if not ( + page + and ( + page.is_current_user_paying() + or page.is_current_user_admin() + ) + ): + st.caption( + """ + Note: Please purchase Gooey.AI credits to use ElevenLabs voices + here.
+ Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above. + """ + ) + + st.session_state.update( + elevenlabs_api_key=None, elevenlabs_voice_id=None + ) + st.selectbox( """ + ###### Voice Name (ElevenLabs) + """, + key="elevenlabs_voice_name", + format_func=str, + options=ELEVEN_LABS_VOICES.keys(), ) - st.selectbox( - """ - ###### Voice name (ElevenLabs) - """, - key="elevenlabs_voice_name", - format_func=str, - options=ELEVEN_LABS_VOICES.keys(), + page.request.session[SESSION_ELEVENLABS_API_KEY] = st.session_state.get( + "elevenlabs_api_key" ) + st.selectbox( """ ###### Voice Model @@ -323,3 +387,28 @@ def _voice_sort_key(voice: texttospeech.Voice): # sort alphabetically voice.name, ) + + +_elevenlabs_category_order = { + "cloned": 1, + "generated": 2, + "premade": 3, +} + + +@st.cache_in_session_state +def fetch_elevenlabs_voices(api_key: str) -> dict[str, str]: + r = requests.get( + "https://api.elevenlabs.io/v1/voices", + headers={"Accept": "application/json", "xi-api-key": api_key}, + ) + r.raise_for_status() + print(r.json()["voices"]) + sorted_voices = sorted( + r.json()["voices"], + key=lambda v: (_elevenlabs_category_order.get(v["category"], 0), v["name"]), + ) + return { + v["voice_id"]: " - ".join([v["name"], *v["labels"].values()]) + for v in sorted_voices + } diff --git a/gooey_ui/components.py b/gooey_ui/components.py index b6919a487..245c9e97e 100644 --- a/gooey_ui/components.py +++ b/gooey_ui/components.py @@ -619,6 +619,33 @@ def text_input( return value or "" +def password_input( + label: str, + value: str = "", + max_chars: str = None, + key: str = None, + help: str = None, + *, + placeholder: str = None, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + **props, +) -> str: + value = _input_widget( + input_type="password", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + label_visibility=label_visibility, + maxLength=max_chars, + placeholder=placeholder, + **props, + ) + return value or "" + + def slider( label: str, min_value: float = None, diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 279dd8b15..cc900c0c1 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -47,6 +47,8 @@ class RequestModel(BaseModel): bark_history_prompt: str | None elevenlabs_voice_name: str | None + elevenlabs_api_key: str | None + elevenlabs_voice_id: str | None elevenlabs_model: str | None elevenlabs_stability: float | None elevenlabs_similarity_boost: float | None diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index de8178cee..dfa5da82b 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -65,6 +65,8 @@ class RequestModel(BaseModel): bark_history_prompt: str | None elevenlabs_voice_name: str | None + elevenlabs_api_key: str | None + elevenlabs_voice_id: str | None elevenlabs_model: str | None elevenlabs_stability: float | None elevenlabs_similarity_boost: float | None @@ -100,6 +102,12 @@ def render_form_v2(self): key="text_prompt", ) + def fields_to_save(self): + fields = super().fields_to_save() + if "elevenlabs_api_key" in fields: + fields.remove("elevenlabs_api_key") + return fields + def validate_form_v2(self): assert st.session_state["text_prompt"], "Text input cannot be empty" @@ -110,7 +118,7 @@ def get_raw_price(self, state: dict): tts_provider = self._get_tts_provider(state) match tts_provider: case TextToSpeechProviders.ELEVEN_LABS: - return self._get_eleven_labs_price(state) + return self._get_elevenlabs_price(state) case _: return super().get_raw_price(state) @@ -126,10 +134,14 @@ def render_output(self): else: st.div() - def _get_eleven_labs_price(self, state: dict): - text = state.get("text_prompt", "") - # 0.079 credits / character ~ 4 credits / 10 words - return len(text) * 0.079 + def _get_elevenlabs_price(self, state: dict): + _, is_user_provided_key = self._get_elevenlabs_api_key(state) + if is_user_provided_key: + return 0 + else: + text = state.get("text_prompt", "") + # 0.079 credits / character ~ 4 credits / 10 words + return len(text) * 0.079 def _get_tts_provider(self, state: dict): tts_provider = state.get("tts_provider", TextToSpeechProviders.UBERDUCK.name) @@ -139,9 +151,11 @@ def _get_tts_provider(self, state: dict): def additional_notes(self): tts_provider = st.session_state.get("tts_provider") if tts_provider == TextToSpeechProviders.ELEVEN_LABS.name: - return """ - *Eleven Labs cost ≈ 4 credits per 10 words* - """ + _, is_user_provided_key = self._get_elevenlabs_api_key(st.session_state) + if is_user_provided_key: + return "*Eleven Labs cost ≈ No additional credit charge given we'll use your API key*" + else: + return "*Eleven Labs cost ≈ 4 credits per 10 words*" else: return "" @@ -239,26 +253,17 @@ def run(self, state: dict): ) case TextToSpeechProviders.ELEVEN_LABS: + xi_api_key, is_custom_key = self._get_elevenlabs_api_key(state) assert ( - self.is_current_user_paying() or self.is_current_user_admin() + is_custom_key + or self.is_current_user_paying() + or self.is_current_user_admin() ), """ Please purchase Gooey.AI credits to use ElevenLabs voices here. """ - # default to first in the mapping - default_voice_model = next(iter(ELEVEN_LABS_MODELS)) - default_voice_name = next(iter(ELEVEN_LABS_VOICES)) - - voice_model = state.get("elevenlabs_model", default_voice_model) - voice_name = state.get("elevenlabs_voice_name", default_voice_name) - - # validate voice_model / voice_name - if voice_model not in ELEVEN_LABS_MODELS: - raise ValueError(f"Invalid model: {voice_model}") - if voice_name not in ELEVEN_LABS_VOICES: - raise ValueError(f"Invalid voice_name: {voice_name}") - else: - voice_id = ELEVEN_LABS_VOICES[voice_name] + voice_model = self._get_elevenlabs_voice_model(state) + voice_id = self._get_elevenlabs_voice_id(state) stability = state.get("elevenlabs_stability", 0.5) similarity_boost = state.get("elevenlabs_similarity_boost", 0.75) @@ -266,7 +271,7 @@ def run(self, state: dict): response = requests.post( f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", headers={ - "xi-api-key": settings.ELEVEN_LABS_API_KEY, + "xi-api-key": xi_api_key, "Accept": "audio/mpeg", }, json={ @@ -285,6 +290,35 @@ def run(self, state: dict): "elevenlabs_gen.mp3", response.content ) + def _get_elevenlabs_voice_model(self, state: dict[str, str]): + default_voice_model = next(iter(ELEVEN_LABS_MODELS)) + voice_model = state.get("elevenlabs_model", default_voice_model) + assert voice_model in ELEVEN_LABS_MODELS, f"Invalid model: {voice_model}" + return voice_model + + def _get_elevenlabs_voice_id(self, state: dict[str, str]): + if state.get("elevenlabs_voice_id"): + assert state.get( + "elevenlabs_api_key" + ), "ElevenLabs API key is required to use a custom voice_id" + return state["elevenlabs_voice_id"] + else: + # default to first in the mapping + default_voice_name = next(iter(ELEVEN_LABS_VOICES)) + voice_name = state.get("elevenlabs_voice_name", default_voice_name) + assert voice_name in ELEVEN_LABS_VOICES, f"Invalid voice_name: {voice_name}" + return ELEVEN_LABS_VOICES[voice_name] # voice_name -> voice_id + + def _get_elevenlabs_api_key(self, state: dict[str, str]) -> tuple[str, bool]: + """ + Returns the 11labs API key and whether it is a user-provided key or the default key + """ + # ElevenLabs is available for non-paying users with their own API key + if state.get("elevenlabs_api_key"): + return state["elevenlabs_api_key"], True + else: + return settings.ELEVEN_LABS_API_KEY, False + def related_workflows(self) -> list: from recipes.VideoBots import VideoBotsPage from recipes.LipsyncTTS import LipsyncTTSPage diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 5857b4ea3..898bd3981 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -205,6 +205,8 @@ class RequestModel(BaseModel): google_pitch: float | None bark_history_prompt: str | None elevenlabs_voice_name: str | None + elevenlabs_api_key: str | None + elevenlabs_voice_id: str | None elevenlabs_model: str | None elevenlabs_stability: float | None elevenlabs_similarity_boost: float | None @@ -430,7 +432,10 @@ def render_settings(self): lipsync_settings() def fields_to_save(self) -> [str]: - return super().fields_to_save() + ["landbot_url"] + fields = super().fields_to_save() + ["landbot_url"] + if "elevenlabs_api_key" in fields: + fields.remove("elevenlabs_api_key") + return fields def render_example(self, state: dict): input_prompt = state.get("input_prompt") @@ -607,7 +612,7 @@ def additional_notes(self): case TextToSpeechProviders.ELEVEN_LABS.name: return f""" - *Base cost = {super().get_raw_price(st.session_state)} credits* - - *Additional Eleven Labs cost ≈ 4 credits per 10 words of the output* + - *Additional {TextToSpeechPage().additional_notes()}* """ case _: return ""