Skip to content

Commit

Permalink
Use TypedDict for voice list and TTS stream
Browse files Browse the repository at this point in the history
Signed-off-by: rany <[email protected]>
  • Loading branch information
rany2 committed Nov 22, 2024
1 parent 48c7f3a commit 5278089
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 17 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
install_requires=[
"aiohttp>=3.8.0",
"certifi>=2023.11.17",
"typing-extensions>=4.1.0",
],
)
7 changes: 4 additions & 3 deletions src/edge_tts/communicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
WebSocketError,
)
from .models import TTSConfig
from .typing import TTSChunk


def get_headers_and_data(
Expand Down Expand Up @@ -297,7 +298,7 @@ def __init__(
"stream_was_called": False,
}

def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
def __parse_metadata(self, data: bytes) -> TTSChunk:
for meta_obj in json.loads(data)["Metadata"]:
meta_type = meta_obj["Type"]
if meta_type == "WordBoundary":
Expand All @@ -316,7 +317,7 @@ def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
raise UnknownResponse(f"Unknown metadata type: {meta_type}")
raise UnexpectedResponse("No WordBoundary metadata found")

async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]:
async def __stream(self) -> AsyncGenerator[TTSChunk, None]:
async def send_command_request() -> None:
"""Sends the request to the service."""

Expand Down Expand Up @@ -479,7 +480,7 @@ async def send_ssml_request() -> None:

async def stream(
self,
) -> AsyncGenerator[Dict[str, Any], None]:
) -> AsyncGenerator[TTSChunk, None]:
"""
Streams audio and metadata from the service.
Expand Down
84 changes: 84 additions & 0 deletions src/edge_tts/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Custom types for edge-tts."""

# pylint: disable=too-few-public-methods

from typing import List

from typing_extensions import Literal, NotRequired, TypedDict


class TTSChunk(TypedDict):
    """TTS chunk data.

    Item type used by the streaming code in communicate.py: it is the
    element type of the ``stream`` / ``__stream`` async generators and the
    return type of ``__parse_metadata``.

    ``type`` is always present; the remaining keys are optional and, per the
    per-field notes below, each belongs to one kind of chunk only.
    """

    # Discriminator telling which of the optional keys to expect.
    type: Literal["audio", "WordBoundary"]
    data: NotRequired[bytes]  # only for audio
    # NOTE(review): the unit of duration/offset is not stated in this module;
    # confirm against the code that builds WordBoundary chunks.
    duration: NotRequired[float]  # only for WordBoundary
    offset: NotRequired[float]  # only for WordBoundary
    text: NotRequired[str]  # only for WordBoundary


# Values permitted for each entry of VoiceTag["ContentCategories"].
_ContentCategory = Literal[
    "Cartoon",
    "Conversation",
    "Copilot",
    "Dialect",
    "General",
    "News",
    "Novel",
    "Sports",
]

# Values permitted for each entry of VoiceTag["VoicePersonalities"].
_VoicePersonality = Literal[
    "Approachable",
    "Authentic",
    "Authority",
    "Bright",
    "Caring",
    "Casual",
    "Cheerful",
    "Clear",
    "Comfort",
    "Confident",
    "Considerate",
    "Conversational",
    "Cute",
    "Expressive",
    "Friendly",
    "Honest",
    "Humorous",
    "Lively",
    "Passion",
    "Pleasant",
    "Positive",
    "Professional",
    "Rational",
    "Reliable",
    "Sincere",
    "Sunshine",
    "Warm",
]


class VoiceTag(TypedDict):
    """VoiceTag data.

    The tag block nested under a voice's ``VoiceTag`` key: two lists of
    labels, each restricted to a fixed set of literal strings.
    """

    ContentCategories: List[_ContentCategory]
    VoicePersonalities: List[_VoicePersonality]


class Voice(TypedDict):
    """Voice data.

    Shape of one entry in the list returned by ``list_voices()`` in
    voices.py, which is parsed from the JSON body of the voice-list request.
    All keys are required.
    """

    Name: str
    ShortName: str
    Gender: Literal["Female", "Male"]
    # VoicesManager.create derives "Language" from the part of Locale that
    # precedes the first "-".
    Locale: str
    # Only one codec value is declared here.
    SuggestedCodec: Literal["audio-24khz-48kbitrate-mono-mp3"]
    FriendlyName: str
    # Only one status value is declared here.
    Status: Literal["GA"]
    # The key shares its name with the VoiceTag TypedDict defined in this
    # module; the annotation refers to that class.
    VoiceTag: VoiceTag


class VoiceManagerVoice(Voice):
    """Voice data for VoiceManager.

    A ``Voice`` with one extra required key, ``Language``. This is the
    element type of ``VoicesManager.voices`` and of the list returned by
    ``VoicesManager.find`` in voices.py.
    """

    # Filled in by VoicesManager.create as voice["Locale"].split("-")[0].
    Language: str
40 changes: 26 additions & 14 deletions src/edge_tts/voices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@

import json
import ssl
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional

import aiohttp
import certifi

from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
from .drm import DRM
from .typing import Voice, VoiceManagerVoice


async def __list_voices(
session: aiohttp.ClientSession, ssl_ctx: ssl.SSLContext, proxy: Optional[str]
) -> Any:
) -> List[Voice]:
"""
Private function that makes the request to the voice list URL and parses the
JSON response. This function is used by list_voices() and makes it easier to
Expand All @@ -26,7 +27,7 @@ async def __list_voices(
proxy (Optional[str]): The proxy to use for the request.
Returns:
dict: A dictionary of voice attributes.
List[Voice]: A list of voices and their attributes.
"""
async with session.get(
f"{VOICE_LIST}&Sec-MS-GEC={DRM.generate_sec_ms_gec()}"
Expand All @@ -36,11 +37,25 @@ async def __list_voices(
ssl=ssl_ctx,
raise_for_status=True,
) as url:
data = json.loads(await url.text())
data: List[Voice] = json.loads(await url.text())

for voice in data:
# Remove leading and trailing whitespace from categories and personalities.
# This has only happened in one case with the zh-CN-YunjianNeural voice
# where there was a leading space in one of the categories.
voice["VoiceTag"]["ContentCategories"] = [
category.strip() # type: ignore
for category in voice["VoiceTag"]["ContentCategories"]
]
voice["VoiceTag"]["VoicePersonalities"] = [
personality.strip() # type: ignore
for personality in voice["VoiceTag"]["VoicePersonalities"]
]

return data


async def list_voices(*, proxy: Optional[str] = None) -> Any:
async def list_voices(*, proxy: Optional[str] = None) -> List[Voice]:
"""
List all available voices and their attributes.
Expand All @@ -51,7 +66,7 @@ async def list_voices(*, proxy: Optional[str] = None) -> Any:
proxy (Optional[str]): The proxy to use for the request.
Returns:
dict: A dictionary of voice attributes.
List[Voice]: A list of voices and their attributes.
"""
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
async with aiohttp.ClientSession(trust_env=True) as session:
Expand All @@ -72,26 +87,23 @@ class VoicesManager:
"""

def __init__(self) -> None:
self.voices: List[Dict[str, Any]] = []
self.voices: List[VoiceManagerVoice] = []
self.called_create: bool = False

@classmethod
async def create(
cls: Any, custom_voices: Optional[List[Dict[str, Any]]] = None
) -> Any:
async def create(cls: Any, custom_voices: Optional[List[Voice]] = None) -> Any:
"""
Creates a VoicesManager object and populates it with all available voices.
"""
self = VoicesManager()
self.voices = await list_voices() if custom_voices is None else custom_voices
voices = await list_voices() if custom_voices is None else custom_voices
self.voices = [
{**voice, **{"Language": voice["Locale"].split("-")[0]}}
for voice in self.voices
{**voice, "Language": voice["Locale"].split("-")[0]} for voice in voices
]
self.called_create = True
return self

def find(self, **kwargs: Any) -> List[Dict[str, Any]]:
def find(self, **kwargs: Any) -> List[VoiceManagerVoice]:
"""
Finds all matching voices based on the provided attributes.
"""
Expand Down

0 comments on commit 5278089

Please sign in to comment.