Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom ElevenLabs API key #202

Merged
merged 12 commits into from
Nov 10, 2023
7 changes: 6 additions & 1 deletion daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class BasePage:
slug_versions: list[str]

sane_defaults: dict = {}

RequestModel: typing.Type[BaseModel]
ResponseModel: typing.Type[BaseModel]

Expand Down Expand Up @@ -154,7 +155,6 @@ def render(self):
)

example_id, run_id, uid = extract_query_params(gooey_get_query_params())

if st.session_state.get(StateKeys.run_status):
channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}"
output = realtime_pull([channel])[0]
Expand Down Expand Up @@ -1150,6 +1150,11 @@ def is_current_user_admin(self) -> bool:
def is_current_user_paying(self) -> bool:
return bool(self.request and self.request.user and self.request.user.is_paying)

def is_current_user_owner(self) -> bool:
return bool(
self.request and self.request.user and self.run_user == self.request.user
)


def get_example_request_body(
request_model: typing.Type[BaseModel],
Expand Down
2 changes: 0 additions & 2 deletions daras_ai_v2/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,6 @@ def _run_chat_model(
case LLMApis.openai:
from openai import OpenAI

print([len(messages), max_tokens, num_outputs, messages])

client = OpenAI()
r = client.chat.completions.create(
model=engine,
Expand Down
123 changes: 106 additions & 17 deletions daras_ai_v2/text_to_speech_settings_widgets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from enum import Enum

import gooey_ui as st
import requests
from google.cloud import texttospeech

import gooey_ui as st
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.redis_cache import redis_cache_decorator

SESSION_ELEVENLABS_API_KEY = "__user__elevenlabs_api_key"

UBERDUCK_VOICES = {
"Aiden Botha": "b01cf18d-0f10-46dd-adc6-562b599fdae4",
"Angus": "7d29a280-8a3e-4c4b-9df4-cbe77e8f4a63",
Expand All @@ -23,7 +26,7 @@

class TextToSpeechProviders(Enum):
GOOGLE_TTS = "Google Cloud Text-to-Speech"
ELEVEN_LABS = "Eleven Labs (Premium)"
ELEVEN_LABS = "Eleven Labs"
UBERDUCK = "uberduck.ai"
BARK = "Bark (suno-ai)"

Expand Down Expand Up @@ -135,7 +138,7 @@ class TextToSpeechProviders(Enum):
}


def text_to_speech_settings(page=None):
def text_to_speech_settings(page):
st.write(
"""
##### 🗣️ Voice Settings
Expand Down Expand Up @@ -227,25 +230,86 @@ def text_to_speech_settings(page=None):

case TextToSpeechProviders.ELEVEN_LABS.name:
with col2:
if not (
page
and (page.is_current_user_paying() or page.is_current_user_admin())
):
st.caption(
if not st.session_state.get("elevenlabs_api_key"):
st.session_state["elevenlabs_api_key"] = page.request.session.get(
SESSION_ELEVENLABS_API_KEY
)

elevenlabs_use_custom_key = st.checkbox(
"Use custom API key + Voice ID",
value=bool(st.session_state.get("elevenlabs_api_key")),
)
if elevenlabs_use_custom_key:
st.session_state["elevenlabs_voice_name"] = None
elevenlabs_api_key = st.text_input(
"""
Note: Please purchase Gooey.AI credits to use ElevenLabs voices
<a href="/account">here</a>.
###### Your ElevenLabs API key
*Read <a target="_blank" href="https://docs.elevenlabs.io/api-reference/authentication">this</a>
to know how to obtain an API key from
ElevenLabs.*
""",
key="elevenlabs_api_key",
)

selected_voice_id = st.session_state.get("elevenlabs_voice_id")
elevenlabs_voices = (
{selected_voice_id: selected_voice_id}
if selected_voice_id
else {}
)

if elevenlabs_api_key:
try:
elevenlabs_voices = fetch_elevenlabs_voices(
elevenlabs_api_key
)
except requests.exceptions.HTTPError as e:
st.error(
f"Invalid ElevenLabs API key. Failed to fetch voices: {e}"
)

st.selectbox(
"""
###### Voice ID (ElevenLabs)
""",
key="elevenlabs_voice_id",
options=elevenlabs_voices.keys(),
format_func=elevenlabs_voices.__getitem__,
)
else:
st.session_state["elevenlabs_api_key"] = None
st.session_state["elevenlabs_voice_id"] = None
if not (
page
and (
page.is_current_user_paying()
or page.is_current_user_admin()
)
):
st.caption(
"""
Note: Please purchase Gooey.AI credits to use ElevenLabs voices
<a href="/account">here</a>.<br/>
Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above.
"""
)

st.session_state.update(
elevenlabs_api_key=None, elevenlabs_voice_id=None
)
st.selectbox(
"""
###### Voice Name (ElevenLabs)
""",
key="elevenlabs_voice_name",
format_func=str,
options=ELEVEN_LABS_VOICES.keys(),
)

st.selectbox(
"""
###### Voice name (ElevenLabs)
""",
key="elevenlabs_voice_name",
format_func=str,
options=ELEVEN_LABS_VOICES.keys(),
page.request.session[SESSION_ELEVENLABS_API_KEY] = st.session_state.get(
"elevenlabs_api_key"
)

st.selectbox(
"""
###### Voice Model
Expand Down Expand Up @@ -323,3 +387,28 @@ def _voice_sort_key(voice: texttospeech.Voice):
# sort alphabetically
voice.name,
)


_elevenlabs_category_order = {
"cloned": 1,
"generated": 2,
"premade": 3,
}


@st.cache_in_session_state
def fetch_elevenlabs_voices(api_key: str) -> dict[str, str]:
r = requests.get(
"https://api.elevenlabs.io/v1/voices",
headers={"Accept": "application/json", "xi-api-key": api_key},
)
r.raise_for_status()
print(r.json()["voices"])
sorted_voices = sorted(
r.json()["voices"],
key=lambda v: (_elevenlabs_category_order.get(v["category"], 0), v["name"]),
)
return {
v["voice_id"]: " - ".join([v["name"], *v["labels"].values()])
for v in sorted_voices
}
27 changes: 27 additions & 0 deletions gooey_ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,33 @@ def text_input(
return value or ""


def password_input(
label: str,
value: str = "",
max_chars: str = None,
key: str = None,
help: str = None,
*,
placeholder: str = None,
disabled: bool = False,
label_visibility: LabelVisibility = "visible",
**props,
) -> str:
value = _input_widget(
input_type="password",
label=label,
value=value,
key=key,
help=help,
disabled=disabled,
label_visibility=label_visibility,
maxLength=max_chars,
placeholder=placeholder,
**props,
)
return value or ""


def slider(
label: str,
min_value: float = None,
Expand Down
2 changes: 2 additions & 0 deletions recipes/LipsyncTTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class RequestModel(BaseModel):
bark_history_prompt: str | None

elevenlabs_voice_name: str | None
elevenlabs_api_key: str | None
elevenlabs_voice_id: str | None
elevenlabs_model: str | None
elevenlabs_stability: float | None
elevenlabs_similarity_boost: float | None
Expand Down
82 changes: 58 additions & 24 deletions recipes/TextToSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class RequestModel(BaseModel):
bark_history_prompt: str | None

elevenlabs_voice_name: str | None
elevenlabs_api_key: str | None
elevenlabs_voice_id: str | None
elevenlabs_model: str | None
elevenlabs_stability: float | None
elevenlabs_similarity_boost: float | None
Expand Down Expand Up @@ -100,6 +102,12 @@ def render_form_v2(self):
key="text_prompt",
)

def fields_to_save(self):
fields = super().fields_to_save()
if "elevenlabs_api_key" in fields:
fields.remove("elevenlabs_api_key")
return fields

def validate_form_v2(self):
assert st.session_state["text_prompt"], "Text input cannot be empty"

Expand All @@ -110,7 +118,7 @@ def get_raw_price(self, state: dict):
tts_provider = self._get_tts_provider(state)
match tts_provider:
case TextToSpeechProviders.ELEVEN_LABS:
return self._get_eleven_labs_price(state)
return self._get_elevenlabs_price(state)
case _:
return super().get_raw_price(state)

Expand All @@ -126,10 +134,14 @@ def render_output(self):
else:
st.div()

def _get_eleven_labs_price(self, state: dict):
text = state.get("text_prompt", "")
# 0.079 credits / character ~ 4 credits / 10 words
return len(text) * 0.079
def _get_elevenlabs_price(self, state: dict):
_, is_user_provided_key = self._get_elevenlabs_api_key(state)
if is_user_provided_key:
return 0
else:
text = state.get("text_prompt", "")
# 0.079 credits / character ~ 4 credits / 10 words
return len(text) * 0.079

def _get_tts_provider(self, state: dict):
tts_provider = state.get("tts_provider", TextToSpeechProviders.UBERDUCK.name)
Expand All @@ -139,9 +151,11 @@ def _get_tts_provider(self, state: dict):
def additional_notes(self):
tts_provider = st.session_state.get("tts_provider")
if tts_provider == TextToSpeechProviders.ELEVEN_LABS.name:
return """
*Eleven Labs cost ≈ 4 credits per 10 words*
"""
_, is_user_provided_key = self._get_elevenlabs_api_key(st.session_state)
if is_user_provided_key:
return "*Eleven Labs cost ≈ No additional credit charge given we'll use your API key*"
else:
return "*Eleven Labs cost ≈ 4 credits per 10 words*"
else:
return ""

Expand Down Expand Up @@ -239,34 +253,25 @@ def run(self, state: dict):
)

case TextToSpeechProviders.ELEVEN_LABS:
xi_api_key, is_custom_key = self._get_elevenlabs_api_key(state)
assert (
self.is_current_user_paying() or self.is_current_user_admin()
is_custom_key
or self.is_current_user_paying()
or self.is_current_user_admin()
), """
Please purchase Gooey.AI credits to use ElevenLabs voices <a href="/account">here</a>.
"""

# default to first in the mapping
default_voice_model = next(iter(ELEVEN_LABS_MODELS))
default_voice_name = next(iter(ELEVEN_LABS_VOICES))

voice_model = state.get("elevenlabs_model", default_voice_model)
voice_name = state.get("elevenlabs_voice_name", default_voice_name)

# validate voice_model / voice_name
if voice_model not in ELEVEN_LABS_MODELS:
raise ValueError(f"Invalid model: {voice_model}")
if voice_name not in ELEVEN_LABS_VOICES:
raise ValueError(f"Invalid voice_name: {voice_name}")
else:
voice_id = ELEVEN_LABS_VOICES[voice_name]
voice_model = self._get_elevenlabs_voice_model(state)
voice_id = self._get_elevenlabs_voice_id(state)

stability = state.get("elevenlabs_stability", 0.5)
similarity_boost = state.get("elevenlabs_similarity_boost", 0.75)

response = requests.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
headers={
"xi-api-key": settings.ELEVEN_LABS_API_KEY,
"xi-api-key": xi_api_key,
"Accept": "audio/mpeg",
},
json={
Expand All @@ -285,6 +290,35 @@ def run(self, state: dict):
"elevenlabs_gen.mp3", response.content
)

def _get_elevenlabs_voice_model(self, state: dict[str, str]):
default_voice_model = next(iter(ELEVEN_LABS_MODELS))
voice_model = state.get("elevenlabs_model", default_voice_model)
assert voice_model in ELEVEN_LABS_MODELS, f"Invalid model: {voice_model}"
return voice_model

def _get_elevenlabs_voice_id(self, state: dict[str, str]):
if state.get("elevenlabs_voice_id"):
assert state.get(
"elevenlabs_api_key"
), "ElevenLabs API key is required to use a custom voice_id"
return state["elevenlabs_voice_id"]
else:
# default to first in the mapping
default_voice_name = next(iter(ELEVEN_LABS_VOICES))
voice_name = state.get("elevenlabs_voice_name", default_voice_name)
assert voice_name in ELEVEN_LABS_VOICES, f"Invalid voice_name: {voice_name}"
return ELEVEN_LABS_VOICES[voice_name] # voice_name -> voice_id

def _get_elevenlabs_api_key(self, state: dict[str, str]) -> tuple[str, bool]:
"""
Returns the 11labs API key and whether it is a user-provided key or the default key
"""
# ElevenLabs is available for non-paying users with their own API key
if state.get("elevenlabs_api_key"):
return state["elevenlabs_api_key"], True
else:
return settings.ELEVEN_LABS_API_KEY, False

def related_workflows(self) -> list:
from recipes.VideoBots import VideoBotsPage
from recipes.LipsyncTTS import LipsyncTTSPage
Expand Down
Loading
Loading