Skip to content

Commit

Permalink
0.0.3
Browse files Browse the repository at this point in the history
Added the ability to get a sample rate and also configured automatic assignment.
  • Loading branch information
daswer123 committed May 26, 2024
1 parent f542f32 commit 0b466be
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ cython_debug/
output.wav
silero_tts/latest_silero_models.yml
/tests/tests_temp
test.py
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "silero-tts"
version = "0.0.2"
version = "0.0.3"
authors = [
{ name="daswer123", email="[email protected]" },
]
Expand Down
23 changes: 17 additions & 6 deletions silero_tts/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main():
parser.add_argument('--language', type=str, help='Language code')
parser.add_argument('--model', type=str, help='Model ID (default: latest version for the language)')
parser.add_argument('--speaker', type=str, help='Speaker name (default: first available speaker for the model)')
parser.add_argument('--sample-rate', type=int, default=48000, help='Sample rate (default: 48000)')
parser.add_argument('--sample-rate', type=int, help='Sample rate (default: highest available for the model)')
parser.add_argument('--device', type=str, default='cpu', help='Device to use (default: cpu)')
parser.add_argument('--text', type=str, help='Text to synthesize')
parser.add_argument('--input-file', type=str, help='Input text file to synthesize')
Expand All @@ -21,13 +21,12 @@ def main():
args = parser.parse_args()

try:

models_file = os.path.join(os.path.dirname(__file__), 'latest_silero_models.yml')

if not os.path.exists(models_file):
logger.warning(f"Models config file not found: {models_file}. Downloading...")
SileroTTS.download_models_config_static(models_file)

if args.list_models:
models = SileroTTS.get_available_models()
logger.info(f"Available models: {models}")
Expand All @@ -43,18 +42,30 @@ def main():

if not args.model:
args.model = SileroTTS.get_latest_model(args.language)
logger.warning(f"Using the latest model for {args.language}: {args.model}")
logger.warning(f"Model not specified. Using the latest model for {args.language}: {args.model}")
logger.info(f"You can specify a different model using the --model flag.")
logger.info(f"Example: --model v4_ru")
logger.info(f"Available models for {args.language}: {', '.join(SileroTTS.get_available_models()[args.language])}")

if not args.sample_rate:
available_sample_rates = SileroTTS.get_available_sample_rates_static(args.language, args.model)
args.sample_rate = max(available_sample_rates)
logger.warning(f"Sample rate not specified. Using the highest available sample rate: {args.sample_rate}")
else:
available_sample_rates = SileroTTS.get_available_sample_rates_static(args.language, args.model)
if args.sample_rate not in available_sample_rates:
logger.warning(f"The specified sample rate {args.sample_rate} is not supported by the model {args.model}.")
logger.info(f"Available sample rates for this model: {', '.join(map(str, available_sample_rates))}")
args.sample_rate = max(available_sample_rates)
logger.info(f"Using the highest available sample rate: {args.sample_rate}")

logger.info(f"Initializing TTS with model: {args.model}, language: {args.language}, speaker: {args.speaker}")
tts = SileroTTS(model_id=args.model, language=args.language, speaker=args.speaker,
sample_rate=args.sample_rate, device=args.device)
logger.success(f"TTS initialized successfully.")

if not args.speaker:
logger.warning(f"Using the default speaker: {tts.speaker}")
logger.warning(f"Speaker not specified. Using the default speaker: {tts.speaker}")
logger.info(f"You can specify a different speaker using the --speaker flag.")
logger.info(f"Example: --speaker aidar")
logger.info(f"Available speakers for model {args.model}: {', '.join(tts.get_available_speakers())}")
Expand Down Expand Up @@ -91,5 +102,5 @@ def main():
except Exception as e:
logger.exception(f"An error occurred: {str(e)}")

if __name__ == '__main__':
if __name__== '__main__':
main()
35 changes: 32 additions & 3 deletions silero_tts/silero_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def download_models_config(self, models_file=None):

def get_available_speakers(self):
    """Return the speaker collection exposed by the loaded TTS model.

    The value is ``self.tts_model.speakers`` passed through unchanged
    (no copy, no validation).
    """
    loaded_model = self.tts_model
    return loaded_model.speakers

def get_available_sample_rates(self):
    """Return the sample rates listed for this instance's model.

    Looks up the ``latest`` entry for ``self.language`` / ``self.model_id``
    under ``self.models_config['tts_models']`` and reads its
    ``sample_rate`` field.

    Returns:
        list: the configured value when it is already a list, a
        one-element list wrapping it when it is any other value, or an
        empty list when the ``sample_rate`` key is absent.

    Raises:
        KeyError: when the language, model id or ``latest`` entry is not
        present (assuming the config is made of plain dicts).
    """
    models_for_language = self.models_config['tts_models'][self.language]
    latest_entry = models_for_language[self.model_id]['latest']
    rates = latest_entry.get('sample_rate', [])

    # A single rate may be stored as a bare value; always hand back a list.
    return rates if isinstance(rates, list) else [rates]

def validate_model(self):
model_config = self.models_config['tts_models'][self.language][self.model_id]['latest']
Expand Down Expand Up @@ -252,14 +261,14 @@ def get_latest_model(language):
@staticmethod
def get_available_languages():
    """Return the language codes listed in the models config file.

    The config is read from ``latest_silero_models.yml`` located next to
    this module. When that file does not exist, a warning is logged and
    ``SileroTTS.download_models_config_static`` is called with the path
    before the file is opened.

    Returns:
        list: the keys of the ``tts_models`` mapping in the parsed config.
    """
    package_dir = os.path.dirname(__file__)
    models_file = os.path.join(package_dir, 'latest_silero_models.yml')

    config_present = os.path.exists(models_file)
    if not config_present:
        logger.warning(f"Models config file not found: {models_file}. Downloading...")
        SileroTTS.download_models_config_static(models_file)

    with open(models_file, 'r', encoding='utf-8') as config_stream:
        parsed_config = yaml.safe_load(config_stream)

    return [*parsed_config['tts_models'].keys()]


Expand All @@ -280,6 +289,26 @@ def download_models_config_static(models_file=None):
raise Exception(f"Failed to download models config file. Status code: {response.status_code}")


@staticmethod
def get_available_sample_rates_static(language, model_id):
    """Return the sample rates the models config lists for a given model.

    Does not need a ``SileroTTS`` instance: the config is read from
    ``latest_silero_models.yml`` located next to this module. When that
    file does not exist, a warning is logged and
    ``SileroTTS.download_models_config_static`` is called with the path
    before the file is opened.

    Args:
        language: key looked up under ``tts_models`` in the config.
        model_id: key looked up under that language; its ``latest`` entry
            is the one consulted.

    Returns:
        list: the ``sample_rate`` value when it is already a list, a
        one-element list wrapping it when it is any other value, or an
        empty list when the key is absent.
    """
    package_dir = os.path.dirname(__file__)
    models_file = os.path.join(package_dir, 'latest_silero_models.yml')

    config_present = os.path.exists(models_file)
    if not config_present:
        logger.warning(f"Models config file not found: {models_file}. Downloading...")
        SileroTTS.download_models_config_static(models_file)

    with open(models_file, 'r', encoding='utf-8') as config_stream:
        parsed_config = yaml.safe_load(config_stream)

    latest_entry = parsed_config['tts_models'][language][model_id]['latest']
    rates = latest_entry.get('sample_rate', [])

    # A single rate may be stored as a bare value; always hand back a list.
    return rates if isinstance(rates, list) else [rates]


if __name__== '__main__':
tts = SileroTTS(model_id='v4_ru',
language='ru',
Expand Down

0 comments on commit 0b466be

Please sign in to comment.