Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TTS to work with user-owned elevenlabs API key #417

Merged
merged 3 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def runner_task(
run_id: str,
uid: str,
channel: str,
unsaved_state: dict[str, typing.Any] = None,
) -> int:
start_time = time()
error_msg = None
Expand Down Expand Up @@ -84,7 +85,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
page = page_cls(request=SimpleNamespace(user=user))
page.setup_sentry()
sr = page.run_doc_sr(run_id, uid)
st.set_session_state(sr.to_dict())
st.set_session_state(sr.to_dict() | (unsaved_state or {}))
set_query_params(dict(run_id=run_id, uid=uid))

try:
Expand Down
17 changes: 15 additions & 2 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,7 @@ def call_runner_task(self, sr: SavedRun):
run_id=sr.run_id,
uid=sr.uid,
channel=self.realtime_channel_name(sr.run_id, sr.uid),
unsaved_state=self._unsaved_state(),
)
| post_runner_tasks.s()
)
Expand Down Expand Up @@ -1782,7 +1783,7 @@ def load_state_defaults(cls, state: dict):
state.setdefault(k, v)
return state

def fields_to_save(self) -> [str]:
def fields_to_save(self) -> list[str]:
# only save the fields in request/response
return [
field_name
Expand All @@ -1794,6 +1795,18 @@ def fields_to_save(self) -> [str]:
StateKeys.run_time,
]

def _unsaved_state(self) -> dict[str, typing.Any]:
result = {}
for field in self.fields_not_to_save():
try:
result[field] = st.session_state[field]
except KeyError:
pass
return result

def fields_not_to_save(self) -> list[str]:
return []

def _examples_tab(self):
allow_hide = self.is_current_user_admin()

Expand Down Expand Up @@ -2059,7 +2072,7 @@ def run_as_api_tab(self):

api_example_generator(
api_url=api_url,
request_body=request_body,
request_body=request_body | self._unsaved_state(),
as_form_data=as_form_data,
as_async=as_async,
)
Expand Down
38 changes: 22 additions & 16 deletions daras_ai_v2/text_to_speech_settings_widgets.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from enum import Enum
import typing
from enum import Enum

import requests
from furl import furl
from daras_ai_v2.azure_asr import azure_auth_header

import gooey_ui as st
from daras_ai_v2 import settings
from daras_ai_v2.azure_asr import azure_auth_header
from daras_ai_v2.custom_enum import GooeyEnum
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.redis_cache import redis_cache_decorator

if typing.TYPE_CHECKING:
from daras_ai_v2.base import BasePage

SESSION_ELEVENLABS_API_KEY = "__user__elevenlabs_api_key"

UBERDUCK_VOICES = {
Expand Down Expand Up @@ -329,20 +332,8 @@ def uberduck_settings():
)


def elevenlabs_selector(page):
if not st.session_state.get("elevenlabs_api_key"):
st.session_state["elevenlabs_api_key"] = page.request.session.get(
SESSION_ELEVENLABS_API_KEY
)

# for backwards compat
if old_voice_name := st.session_state.pop("elevenlabs_voice_name", None):
try:
st.session_state["elevenlabs_voice_id"] = OLD_ELEVEN_LABS_VOICES[
old_voice_name
]
except KeyError:
pass
def elevenlabs_selector(page: "BasePage"):
elevenlabs_init_state(page)

elevenlabs_use_custom_key = st.checkbox(
"Use custom API key + Voice ID",
Expand Down Expand Up @@ -406,6 +397,21 @@ def elevenlabs_selector(page):
)


def elevenlabs_init_state(page: "BasePage"):
if not st.session_state.get("elevenlabs_api_key"):
st.session_state["elevenlabs_api_key"] = page.request.session.get(
SESSION_ELEVENLABS_API_KEY
)
# for backwards compat
if old_voice_name := st.session_state.pop("elevenlabs_voice_name", None):
try:
st.session_state["elevenlabs_voice_id"] = OLD_ELEVEN_LABS_VOICES[
old_voice_name
]
except KeyError:
pass


def elevenlabs_settings():
col1, col2 = st.columns(2)
with col1:
Expand Down
22 changes: 16 additions & 6 deletions recipes/TextToSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OpenAI_TTS_Models,
OpenAI_TTS_Voices,
OLD_ELEVEN_LABS_VOICES,
elevenlabs_init_state,
)

DEFAULT_TTS_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a73181ce-9457-11ee-8edd-02420a0001c7/Voice%20generators.jpg.png"
Expand Down Expand Up @@ -97,6 +98,21 @@ def fallback_preivew_image(self) -> str | None:
def get_example_preferred_fields(cls, state: dict) -> list[str]:
return ["tts_provider"]

def fields_not_to_save(self):
return ["elevenlabs_api_key"]

def fields_to_save(self):
fields = super().fields_to_save()
try:
fields.remove("elevenlabs_api_key")
except ValueError:
pass
return fields

def run_as_api_tab(self):
elevenlabs_init_state(self)
super().run_as_api_tab()

def preview_description(self, state: dict) -> str:
return "Input your text, pick a voice & a Text-to-Speech AI engine to create audio. Compare the best voice generators from Google, UberDuck.ai & more to add automated voices to your podcast, YouTube videos, website, or app."

Expand All @@ -123,12 +139,6 @@ def render_form_v2(self):
)
text_to_speech_provider_selector(self)

def fields_to_save(self):
fields = super().fields_to_save()
if "elevenlabs_api_key" in fields:
fields.remove("elevenlabs_api_key")
return fields

def validate_form_v2(self):
assert st.session_state.get("text_prompt"), "Text input cannot be empty"
assert st.session_state.get("tts_provider"), "Please select a TTS provider"
Expand Down
14 changes: 12 additions & 2 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
TextToSpeechProviders,
text_to_speech_settings,
text_to_speech_provider_selector,
elevenlabs_init_state,
)
from daras_ai_v2.vector_search import DocSearchRequest
from functions.recipe_functions import LLMTools
Expand Down Expand Up @@ -545,12 +546,21 @@ def render_settings(self):
key="tools",
)

def fields_not_to_save(self):
return ["elevenlabs_api_key"]

def fields_to_save(self) -> [str]:
fields = super().fields_to_save() + ["landbot_url"]
if "elevenlabs_api_key" in fields:
fields = super().fields_to_save()
try:
fields.remove("elevenlabs_api_key")
except ValueError:
pass
return fields

def run_as_api_tab(self):
elevenlabs_init_state(self)
super().run_as_api_tab()

def render_example(self, state: dict):
input_prompt = state.get("input_prompt")
if input_prompt:
Expand Down