-
Notifications
You must be signed in to change notification settings - Fork 0
/
tts_provider.py
107 lines (82 loc) · 2.53 KB
/
tts_provider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import utils
from enum import Enum
from abc import ABC, abstractmethod
class TTSProvider(ABC):
    """Abstract base class for text-to-speech backends.

    A concrete provider describes the audio it produces (channel count,
    sample rate, dtype, block size, container format, maximum length) and
    implements synthesis to a file, to a stream, or as a raw response.

    Model and voice listings are obtained from the provider-specific
    ``_list_models`` / ``_list_voices`` hooks and routed through the
    ``utils.list_models`` / ``utils.list_voices`` helpers together with
    ``allows_list_caching()`` and a cache directory path (presumably so the
    lists can be cached on disk — confirm against ``utils``).

    Voice entries are strings, either a bare identifier or
    ``"<id>\\t<name>"``; ``get_voice_id`` / ``get_voice_name`` split them.
    """

    class Format(Enum):
        """Audio container formats a provider can produce."""

        WAV = 1
        MP3 = 2

    # ------------------------------------------------------------------
    # Provider description (must be implemented by subclasses)
    # ------------------------------------------------------------------

    @abstractmethod
    def name(self):
        """Return the provider's name.

        Passed as the first argument to the ``utils`` list helpers
        (presumably used to key the model/voice caches — confirm).
        """

    @abstractmethod
    def channels(self):
        """Return the number of audio channels the provider emits."""

    @abstractmethod
    def samplerate(self):
        """Return the sample rate of the provider's audio."""

    @abstractmethod
    def dtype(self):
        """Return the sample data type of the provider's audio."""

    def volumegain(self):
        """Return the volume gain to apply to this provider's audio.

        Defaults to ``8``; providers override this only when they need a
        different value.
        """
        # Deliberately not abstract: a concrete default that every subclass
        # is forced to override would be unreachable except via super().
        return 8

    @abstractmethod
    def blocksize(self):
        """Return the block size to use when handling the audio."""

    @abstractmethod
    def format(self) -> Format:
        """Return the :class:`TTSProvider.Format` of the generated audio."""

    @abstractmethod
    def max_length(self):
        """Return the provider's maximum input length (unit is provider-defined)."""

    @abstractmethod
    def _list_models(self):
        """Fetch the provider's available models (uncached)."""

    @abstractmethod
    def default_voice(self):
        """Return the voice used when no matching voice name is found."""

    @abstractmethod
    def default_model(self):
        """Return the model to use when none is specified."""

    @abstractmethod
    def _list_voices(self):
        """Fetch the provider's available voices (uncached)."""

    # ------------------------------------------------------------------
    # Model / voice listing helpers
    # ------------------------------------------------------------------

    def allows_list_caching(self) -> bool:
        """Return whether model/voice lists may be cached (default ``True``)."""
        return True

    def list_models(self, cache_directory_path):
        """Return the provider's models via ``utils.list_models``.

        Args:
            cache_directory_path: Directory handed to ``utils`` for caching.
        """
        return utils.list_models(self.name(), self._list_models, self.allows_list_caching(), cache_directory_path)

    def list_voices(self, cache_directory_path):
        """Return the provider's voices via ``utils.list_voices``.

        Args:
            cache_directory_path: Directory handed to ``utils`` for caching.
        """
        return utils.list_voices(self.name(), self._list_voices, self.allows_list_caching(), cache_directory_path)

    def get_voice_name(self, voice: str) -> str:
        """Return the display name of a voice entry.

        For ``"<id>\\t<name>"`` this is the second tab-separated field; an
        entry without a tab is returned unchanged.
        """
        return voice.split('\t')[1] if '\t' in voice else voice

    def get_voice_id(self, voice: str) -> str:
        """Return the identifier of a voice entry.

        For ``"<id>\\t<name>"`` this is the text before the first tab; an
        entry without a tab is returned unchanged.
        """
        # partition() yields the whole string as the head when no tab exists.
        return voice.partition('\t')[0]

    def get_voice_by_name(self, voice_name, cache_directory_path):
        """Look up a voice entry by its display name.

        Args:
            voice_name: Display name to match exactly. Falsy values skip the
                lookup entirely.
            cache_directory_path: Forwarded to :meth:`list_voices`.

        Returns:
            The first voice entry whose display name equals ``voice_name``,
            otherwise ``default_voice()``.
        """
        if not voice_name:
            return self.default_voice()
        for voice in self.list_voices(cache_directory_path):
            if self.get_voice_name(voice) == voice_name:
                return voice
        return self.default_voice()

    # ------------------------------------------------------------------
    # Synthesis (must be implemented by subclasses)
    # ------------------------------------------------------------------

    @abstractmethod
    def text_to_speech(self, text, model, voice_id, speed, audio_file):
        """Synthesize ``text`` and write the audio to ``audio_file``."""

    @abstractmethod
    def text_to_speech_stream(self, text, model, voice_id, speed, virtual_audio_file):
        """Synthesize ``text`` into ``virtual_audio_file`` as a stream."""

    @abstractmethod
    def get_response(self, text, model, voice_id, speed):
        """Return the provider's raw synthesis response for ``text``."""

    def close(self):
        """Release provider resources; the base implementation does nothing."""
def get_tts_providers():
    """Instantiate every built-in TTS provider.

    Provider modules are imported inside this function rather than at module
    top level (presumably to avoid a circular import, since the provider
    classes likely subclass TTSProvider from this module — confirm). Each
    provider is imported and then constructed in turn, in a fixed order:
    OpenAI, ElevenLabs, PlayHT, Print.

    Returns:
        list: One instance of each provider, in the order above.
    """
    from openai_tts_provider import OpenAITTSProvider
    openai_provider = OpenAITTSProvider()

    from elevenlabs_tts_provider import ElevenLabsTTSProvider
    elevenlabs_provider = ElevenLabsTTSProvider()

    from playht_tts_provider import PlayHTProvider
    playht_provider = PlayHTProvider()

    from print_tts_provider import PrintTTSProvider
    print_provider = PrintTTSProvider()

    return [openai_provider, elevenlabs_provider, playht_provider, print_provider]