From 52780891d3b6b4ec3ee0cd161904d4e15a4f222f Mon Sep 17 00:00:00 2001 From: rany Date: Fri, 22 Nov 2024 20:02:01 +0200 Subject: [PATCH] Use TypedDict for voice list and TTS stream Signed-off-by: rany --- setup.py | 1 + src/edge_tts/communicate.py | 7 ++-- src/edge_tts/typing.py | 84 +++++++++++++++++++++++++++++++++++++ src/edge_tts/voices.py | 40 +++++++++++------- 4 files changed, 115 insertions(+), 17 deletions(-) create mode 100644 src/edge_tts/typing.py diff --git a/setup.py b/setup.py index d0da9a7..d2d5cd6 100644 --- a/setup.py +++ b/setup.py @@ -6,5 +6,6 @@ install_requires=[ "aiohttp>=3.8.0", "certifi>=2023.11.17", + "typing-extensions>=4.1.0", ], ) diff --git a/src/edge_tts/communicate.py b/src/edge_tts/communicate.py index 382d2e6..5060dad 100644 --- a/src/edge_tts/communicate.py +++ b/src/edge_tts/communicate.py @@ -35,6 +35,7 @@ WebSocketError, ) from .models import TTSConfig +from .typing import TTSChunk def get_headers_and_data( @@ -297,7 +298,7 @@ def __init__( "stream_was_called": False, } - def __parse_metadata(self, data: bytes) -> Dict[str, Any]: + def __parse_metadata(self, data: bytes) -> TTSChunk: for meta_obj in json.loads(data)["Metadata"]: meta_type = meta_obj["Type"] if meta_type == "WordBoundary": @@ -316,7 +317,7 @@ def __parse_metadata(self, data: bytes) -> Dict[str, Any]: raise UnknownResponse(f"Unknown metadata type: {meta_type}") raise UnexpectedResponse("No WordBoundary metadata found") - async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]: + async def __stream(self) -> AsyncGenerator[TTSChunk, None]: async def send_command_request() -> None: """Sends the request to the service.""" @@ -479,7 +480,7 @@ async def send_ssml_request() -> None: async def stream( self, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[TTSChunk, None]: """ Streams audio and metadata from the service. diff --git a/src/edge_tts/typing.py b/src/edge_tts/typing.py new file mode 100644 index 0000000..cdd2161 --- /dev/null +++ b/src/edge_tts/typing.py @@ -0,0 +1,84 @@ +"""Custom types for edge-tts.""" + +# pylint: disable=too-few-public-methods + +from typing import List + +from typing_extensions import Literal, NotRequired, TypedDict + + +class TTSChunk(TypedDict): + """TTS chunk data.""" + + type: Literal["audio", "WordBoundary"] + data: NotRequired[bytes] # only for audio + duration: NotRequired[float] # only for WordBoundary + offset: NotRequired[float] # only for WordBoundary + text: NotRequired[str] # only for WordBoundary + + +class VoiceTag(TypedDict): + """VoiceTag data.""" + + ContentCategories: List[ + Literal[ + "Cartoon", + "Conversation", + "Copilot", + "Dialect", + "General", + "News", + "Novel", + "Sports", + ] + ] + VoicePersonalities: List[ + Literal[ + "Approachable", + "Authentic", + "Authority", + "Bright", + "Caring", + "Casual", + "Cheerful", + "Clear", + "Comfort", + "Confident", + "Considerate", + "Conversational", + "Cute", + "Expressive", + "Friendly", + "Honest", + "Humorous", + "Lively", + "Passion", + "Pleasant", + "Positive", + "Professional", + "Rational", + "Reliable", + "Sincere", + "Sunshine", + "Warm", + ] + ] + + +class Voice(TypedDict): + """Voice data.""" + + Name: str + ShortName: str + Gender: Literal["Female", "Male"] + Locale: str + SuggestedCodec: Literal["audio-24khz-48kbitrate-mono-mp3"] + FriendlyName: str + Status: Literal["GA"] + VoiceTag: VoiceTag + + +class VoiceManagerVoice(Voice): + """Voice data for VoiceManager.""" + + Language: str diff --git a/src/edge_tts/voices.py b/src/edge_tts/voices.py index 4bf02e4..f7e60ba 100644 --- a/src/edge_tts/voices.py +++ b/src/edge_tts/voices.py @@ -3,18 +3,19 @@ import json import ssl -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional import aiohttp import certifi from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST from .drm import DRM +from .typing import Voice, VoiceManagerVoice async def __list_voices( session: aiohttp.ClientSession, ssl_ctx: ssl.SSLContext, proxy: Optional[str] -) -> Any: +) -> List[Voice]: """ Private function that makes the request to the voice list URL and parses the JSON response. This function is used by list_voices() and makes it easier to @@ -26,7 +27,7 @@ async def __list_voices( proxy (Optional[str]): The proxy to use for the request. Returns: - dict: A dictionary of voice attributes. + List[Voice]: A list of voices and their attributes. """ async with session.get( f"{VOICE_LIST}&Sec-MS-GEC={DRM.generate_sec_ms_gec()}" @@ -36,11 +37,25 @@ async def __list_voices( ssl=ssl_ctx, raise_for_status=True, ) as url: - data = json.loads(await url.text()) + data: List[Voice] = json.loads(await url.text()) + + for voice in data: + # Remove leading and trailing whitespace from categories and personalities. + # This has only happened in one case with the zh-CN-YunjianNeural voice + # where there was a leading space in one of the categories. + voice["VoiceTag"]["ContentCategories"] = [ + category.strip() # type: ignore + for category in voice["VoiceTag"]["ContentCategories"] + ] + voice["VoiceTag"]["VoicePersonalities"] = [ + personality.strip() # type: ignore + for personality in voice["VoiceTag"]["VoicePersonalities"] + ] + return data -async def list_voices(*, proxy: Optional[str] = None) -> Any: +async def list_voices(*, proxy: Optional[str] = None) -> List[Voice]: """ List all available voices and their attributes. @@ -51,7 +66,7 @@ async def list_voices(*, proxy: Optional[str] = None) -> Any: proxy (Optional[str]): The proxy to use for the request. Returns: - dict: A dictionary of voice attributes. + List[Voice]: A list of voices and their attributes. """ ssl_ctx = ssl.create_default_context(cafile=certifi.where()) async with aiohttp.ClientSession(trust_env=True) as session: @@ -72,26 +87,23 @@ class VoicesManager: """ def __init__(self) -> None: - self.voices: List[Dict[str, Any]] = [] + self.voices: List[VoiceManagerVoice] = [] self.called_create: bool = False @classmethod - async def create( - cls: Any, custom_voices: Optional[List[Dict[str, Any]]] = None - ) -> Any: + async def create(cls: Any, custom_voices: Optional[List[Voice]] = None) -> Any: """ Creates a VoicesManager object and populates it with all available voices. """ self = VoicesManager() - self.voices = await list_voices() if custom_voices is None else custom_voices + voices = await list_voices() if custom_voices is None else custom_voices self.voices = [ - {**voice, **{"Language": voice["Locale"].split("-")[0]}} - for voice in self.voices + {**voice, "Language": voice["Locale"].split("-")[0]} for voice in voices ] self.called_create = True return self - def find(self, **kwargs: Any) -> List[Dict[str, Any]]: + def find(self, **kwargs: Any) -> List[VoiceManagerVoice]: """ Finds all matching voices based on the provided attributes. """