Skip to content

Commit

Permalink
Merge pull request #202 from GooeyAI/11labs-key
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy authored Nov 10, 2023
2 parents 1cdbd3c + 31a4372 commit 7683f0e
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 46 deletions.
7 changes: 6 additions & 1 deletion daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class BasePage:
slug_versions: list[str]

sane_defaults: dict = {}

RequestModel: typing.Type[BaseModel]
ResponseModel: typing.Type[BaseModel]

Expand Down Expand Up @@ -154,7 +155,6 @@ def render(self):
)

example_id, run_id, uid = extract_query_params(gooey_get_query_params())

if st.session_state.get(StateKeys.run_status):
channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}"
output = realtime_pull([channel])[0]
Expand Down Expand Up @@ -1150,6 +1150,11 @@ def is_current_user_admin(self) -> bool:
def is_current_user_paying(self) -> bool:
    """
    Tell whether the user attached to the current request is a paying user.

    Returns False when there is no request, or no user on the request.
    """
    user = self.request.user if self.request else None
    return bool(user and user.is_paying)

def is_current_user_owner(self) -> bool:
    """
    Tell whether the user attached to the current request is `self.run_user`.

    Returns False when there is no request, or no user on the request.
    """
    if not self.request:
        return False
    if not self.request.user:
        return False
    return bool(self.run_user == self.request.user)


def get_example_request_body(
request_model: typing.Type[BaseModel],
Expand Down
2 changes: 0 additions & 2 deletions daras_ai_v2/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,6 @@ def _run_chat_model(
case LLMApis.openai:
from openai import OpenAI

print([len(messages), max_tokens, num_outputs, messages])

client = OpenAI()
r = client.chat.completions.create(
model=engine,
Expand Down
123 changes: 106 additions & 17 deletions daras_ai_v2/text_to_speech_settings_widgets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from enum import Enum

import gooey_ui as st
import requests
from google.cloud import texttospeech

import gooey_ui as st
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.redis_cache import redis_cache_decorator

SESSION_ELEVENLABS_API_KEY = "__user__elevenlabs_api_key"

UBERDUCK_VOICES = {
"Aiden Botha": "b01cf18d-0f10-46dd-adc6-562b599fdae4",
"Angus": "7d29a280-8a3e-4c4b-9df4-cbe77e8f4a63",
Expand All @@ -23,7 +26,7 @@

class TextToSpeechProviders(Enum):
GOOGLE_TTS = "Google Cloud Text-to-Speech"
ELEVEN_LABS = "Eleven Labs (Premium)"
ELEVEN_LABS = "Eleven Labs"
UBERDUCK = "uberduck.ai"
BARK = "Bark (suno-ai)"

Expand Down Expand Up @@ -135,7 +138,7 @@ class TextToSpeechProviders(Enum):
}


def text_to_speech_settings(page=None):
def text_to_speech_settings(page):
st.write(
"""
##### 🗣️ Voice Settings
Expand Down Expand Up @@ -227,25 +230,86 @@ def text_to_speech_settings(page=None):

case TextToSpeechProviders.ELEVEN_LABS.name:
with col2:
if not (
page
and (page.is_current_user_paying() or page.is_current_user_admin())
):
st.caption(
if not st.session_state.get("elevenlabs_api_key"):
st.session_state["elevenlabs_api_key"] = page.request.session.get(
SESSION_ELEVENLABS_API_KEY
)

elevenlabs_use_custom_key = st.checkbox(
"Use custom API key + Voice ID",
value=bool(st.session_state.get("elevenlabs_api_key")),
)
if elevenlabs_use_custom_key:
st.session_state["elevenlabs_voice_name"] = None
elevenlabs_api_key = st.text_input(
"""
Note: Please purchase Gooey.AI credits to use ElevenLabs voices
<a href="/account">here</a>.
###### Your ElevenLabs API key
*Read <a target="_blank" href="https://docs.elevenlabs.io/api-reference/authentication">this</a>
to know how to obtain an API key from
ElevenLabs.*
""",
key="elevenlabs_api_key",
)

selected_voice_id = st.session_state.get("elevenlabs_voice_id")
elevenlabs_voices = (
{selected_voice_id: selected_voice_id}
if selected_voice_id
else {}
)

if elevenlabs_api_key:
try:
elevenlabs_voices = fetch_elevenlabs_voices(
elevenlabs_api_key
)
except requests.exceptions.HTTPError as e:
st.error(
f"Invalid ElevenLabs API key. Failed to fetch voices: {e}"
)

st.selectbox(
"""
###### Voice ID (ElevenLabs)
""",
key="elevenlabs_voice_id",
options=elevenlabs_voices.keys(),
format_func=elevenlabs_voices.__getitem__,
)
else:
st.session_state["elevenlabs_api_key"] = None
st.session_state["elevenlabs_voice_id"] = None
if not (
page
and (
page.is_current_user_paying()
or page.is_current_user_admin()
)
):
st.caption(
"""
Note: Please purchase Gooey.AI credits to use ElevenLabs voices
<a href="/account">here</a>.<br/>
Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above.
"""
)

st.session_state.update(
elevenlabs_api_key=None, elevenlabs_voice_id=None
)
st.selectbox(
"""
###### Voice Name (ElevenLabs)
""",
key="elevenlabs_voice_name",
format_func=str,
options=ELEVEN_LABS_VOICES.keys(),
)

st.selectbox(
"""
###### Voice name (ElevenLabs)
""",
key="elevenlabs_voice_name",
format_func=str,
options=ELEVEN_LABS_VOICES.keys(),
page.request.session[SESSION_ELEVENLABS_API_KEY] = st.session_state.get(
"elevenlabs_api_key"
)

st.selectbox(
"""
###### Voice Model
Expand Down Expand Up @@ -323,3 +387,28 @@ def _voice_sort_key(voice: texttospeech.Voice):
# sort alphabetically
voice.name,
)


# Sort rank for each ElevenLabs voice category when listing voices: lower
# ranks come first. fetch_elevenlabs_voices() looks categories up with a
# default of 0, so a category missing from this table sorts before all of these.
_elevenlabs_category_order = {
    "cloned": 1,
    "generated": 2,
    "premade": 3,
}


@st.cache_in_session_state
def fetch_elevenlabs_voices(api_key: str) -> dict[str, str]:
    """
    Fetch the voices available to an ElevenLabs account.

    Args:
        api_key: ElevenLabs API key, sent as the `xi-api-key` header.

    Returns:
        A mapping of `voice_id` -> display label. The label is the voice name
        followed by its label values, joined with " - ". Entries are ordered by
        category rank (see `_elevenlabs_category_order`) and then by name.

    Raises:
        requests.exceptions.HTTPError: if ElevenLabs answers with an error
            status (the settings widget reports this as an invalid API key).
    """
    # NOTE(review): no timeout is passed, so a stalled connection blocks the
    # caller indefinitely. Adding one would surface requests.exceptions.Timeout,
    # which the settings widget does not catch today - handle both together.
    r = requests.get(
        "https://api.elevenlabs.io/v1/voices",
        headers={"Accept": "application/json", "xi-api-key": api_key},
    )
    r.raise_for_status()

    # Parse the body once, and do not print it: the listing includes the
    # account's private cloned/generated voices.
    voices = r.json()["voices"]
    sorted_voices = sorted(
        voices,
        key=lambda v: (
            # a voice without a (known) category ranks 0, i.e. sorts first
            _elevenlabs_category_order.get(v.get("category"), 0),
            v["name"],
        ),
    )
    return {
        # `labels` may be absent or null for a voice; treat that as "no labels"
        v["voice_id"]: " - ".join([v["name"], *(v.get("labels") or {}).values()])
        for v in sorted_voices
    }
27 changes: 27 additions & 0 deletions gooey_ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,33 @@ def text_input(
return value or ""


def password_input(
    label: str,
    value: str = "",
    max_chars: int = None,
    key: str = None,
    help: str = None,
    *,
    placeholder: str = None,
    disabled: bool = False,
    label_visibility: LabelVisibility = "visible",
    **props,
) -> str:
    """
    Render an input widget of type "password".

    Mirrors the signature of `text_input`, but asks `_input_widget` for
    `input_type="password"`.

    Args:
        label: label shown for the field.
        value: initial value of the field.
        max_chars: maximum number of characters, forwarded as the `maxLength`
            prop (hence an integer, not a string).
        key: key passed through to `_input_widget`.
        help: help text passed through to `_input_widget`.
        placeholder: placeholder text for the empty field.
        disabled: whether the field is disabled.
        label_visibility: passed through to `_input_widget`.
        **props: any extra props, forwarded unchanged to `_input_widget`.

    Returns:
        The current value of the field, or "" when the widget yields a falsy
        value (e.g. None).
    """
    value = _input_widget(
        input_type="password",
        label=label,
        value=value,
        key=key,
        help=help,
        disabled=disabled,
        label_visibility=label_visibility,
        maxLength=max_chars,
        placeholder=placeholder,
        **props,
    )
    # normalise "no value" to an empty string so callers always get a str
    return value or ""


def slider(
label: str,
min_value: float = None,
Expand Down
2 changes: 2 additions & 0 deletions recipes/LipsyncTTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class RequestModel(BaseModel):
bark_history_prompt: str | None

elevenlabs_voice_name: str | None
elevenlabs_api_key: str | None
elevenlabs_voice_id: str | None
elevenlabs_model: str | None
elevenlabs_stability: float | None
elevenlabs_similarity_boost: float | None
Expand Down
82 changes: 58 additions & 24 deletions recipes/TextToSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class RequestModel(BaseModel):
bark_history_prompt: str | None

elevenlabs_voice_name: str | None
elevenlabs_api_key: str | None
elevenlabs_voice_id: str | None
elevenlabs_model: str | None
elevenlabs_stability: float | None
elevenlabs_similarity_boost: float | None
Expand Down Expand Up @@ -100,6 +102,12 @@ def render_form_v2(self):
key="text_prompt",
)

def fields_to_save(self):
    """
    Return the parent's list of fields to save, minus `elevenlabs_api_key`.

    The list returned by the parent implementation is modified in place and
    returned, so the user's ElevenLabs API key is left out of what gets saved.
    """
    saved_fields = super().fields_to_save()
    try:
        saved_fields.remove("elevenlabs_api_key")
    except ValueError:
        # the key was not listed in the first place - nothing to strip
        pass
    return saved_fields

def validate_form_v2(self):
assert st.session_state["text_prompt"], "Text input cannot be empty"

Expand All @@ -110,7 +118,7 @@ def get_raw_price(self, state: dict):
tts_provider = self._get_tts_provider(state)
match tts_provider:
case TextToSpeechProviders.ELEVEN_LABS:
return self._get_eleven_labs_price(state)
return self._get_elevenlabs_price(state)
case _:
return super().get_raw_price(state)

Expand All @@ -126,10 +134,14 @@ def render_output(self):
else:
st.div()

def _get_eleven_labs_price(self, state: dict):
text = state.get("text_prompt", "")
# 0.079 credits / character ~ 4 credits / 10 words
return len(text) * 0.079
def _get_elevenlabs_price(self, state: dict):
_, is_user_provided_key = self._get_elevenlabs_api_key(state)
if is_user_provided_key:
return 0
else:
text = state.get("text_prompt", "")
# 0.079 credits / character ~ 4 credits / 10 words
return len(text) * 0.079

def _get_tts_provider(self, state: dict):
tts_provider = state.get("tts_provider", TextToSpeechProviders.UBERDUCK.name)
Expand All @@ -139,9 +151,11 @@ def _get_tts_provider(self, state: dict):
def additional_notes(self):
tts_provider = st.session_state.get("tts_provider")
if tts_provider == TextToSpeechProviders.ELEVEN_LABS.name:
return """
*Eleven Labs cost ≈ 4 credits per 10 words*
"""
_, is_user_provided_key = self._get_elevenlabs_api_key(st.session_state)
if is_user_provided_key:
return "*Eleven Labs cost ≈ No additional credit charge given we'll use your API key*"
else:
return "*Eleven Labs cost ≈ 4 credits per 10 words*"
else:
return ""

Expand Down Expand Up @@ -239,34 +253,25 @@ def run(self, state: dict):
)

case TextToSpeechProviders.ELEVEN_LABS:
xi_api_key, is_custom_key = self._get_elevenlabs_api_key(state)
assert (
self.is_current_user_paying() or self.is_current_user_admin()
is_custom_key
or self.is_current_user_paying()
or self.is_current_user_admin()
), """
Please purchase Gooey.AI credits to use ElevenLabs voices <a href="/account">here</a>.
"""

# default to first in the mapping
default_voice_model = next(iter(ELEVEN_LABS_MODELS))
default_voice_name = next(iter(ELEVEN_LABS_VOICES))

voice_model = state.get("elevenlabs_model", default_voice_model)
voice_name = state.get("elevenlabs_voice_name", default_voice_name)

# validate voice_model / voice_name
if voice_model not in ELEVEN_LABS_MODELS:
raise ValueError(f"Invalid model: {voice_model}")
if voice_name not in ELEVEN_LABS_VOICES:
raise ValueError(f"Invalid voice_name: {voice_name}")
else:
voice_id = ELEVEN_LABS_VOICES[voice_name]
voice_model = self._get_elevenlabs_voice_model(state)
voice_id = self._get_elevenlabs_voice_id(state)

stability = state.get("elevenlabs_stability", 0.5)
similarity_boost = state.get("elevenlabs_similarity_boost", 0.75)

response = requests.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
headers={
"xi-api-key": settings.ELEVEN_LABS_API_KEY,
"xi-api-key": xi_api_key,
"Accept": "audio/mpeg",
},
json={
Expand All @@ -285,6 +290,35 @@ def run(self, state: dict):
"elevenlabs_gen.mp3", response.content
)

def _get_elevenlabs_voice_model(self, state: dict[str, str]):
    """
    Resolve the ElevenLabs model to use for this run.

    Falls back to the first key of `ELEVEN_LABS_MODELS` when `elevenlabs_model`
    is missing from `state` or is explicitly None. The request model declares
    the field as `str | None`, and `state.get(key, default)` alone would hand
    back that None instead of the default.

    Raises:
        AssertionError: if the requested model is not in `ELEVEN_LABS_MODELS`.
    """
    voice_model = state.get("elevenlabs_model")
    if voice_model is None:
        # default to first in the mapping
        voice_model = next(iter(ELEVEN_LABS_MODELS))
    assert voice_model in ELEVEN_LABS_MODELS, f"Invalid model: {voice_model}"
    return voice_model

def _get_elevenlabs_voice_id(self, state: dict[str, str]):
if state.get("elevenlabs_voice_id"):
assert state.get(
"elevenlabs_api_key"
), "ElevenLabs API key is required to use a custom voice_id"
return state["elevenlabs_voice_id"]
else:
# default to first in the mapping
default_voice_name = next(iter(ELEVEN_LABS_VOICES))
voice_name = state.get("elevenlabs_voice_name", default_voice_name)
assert voice_name in ELEVEN_LABS_VOICES, f"Invalid voice_name: {voice_name}"
return ELEVEN_LABS_VOICES[voice_name] # voice_name -> voice_id

def _get_elevenlabs_api_key(self, state: dict[str, str]) -> tuple[str, bool]:
"""
Returns the 11labs API key and whether it is a user-provided key or the default key
"""
# ElevenLabs is available for non-paying users with their own API key
if state.get("elevenlabs_api_key"):
return state["elevenlabs_api_key"], True
else:
return settings.ELEVEN_LABS_API_KEY, False

def related_workflows(self) -> list:
from recipes.VideoBots import VideoBotsPage
from recipes.LipsyncTTS import LipsyncTTSPage
Expand Down
Loading

0 comments on commit 7683f0e

Please sign in to comment.