Skip to content

Commit

Permalink
0.0.5
Browse files Browse the repository at this point in the history
Add control for log level
local model download
  • Loading branch information
daswer123 committed Oct 18, 2024
1 parent f8f4607 commit 2ea742c
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 21 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ python -m silero_tts [options]
- `--input-dir INPUT_DIR`: Specify the input directory with text files to synthesize
- `--output-file OUTPUT_FILE`: Specify the output audio file (default: output.wav)
- `--output-dir OUTPUT_DIR`: Specify the output directory for synthesized audio files (default: output)
- `--log-level INFO` : Specify log-level, you can turn off use NONE value (default: INFO)

#### Examples

Expand Down
1 change: 1 addition & 0 deletions README_RU.MD
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ python -m silero_tts [параметры]
- `--input-dir INPUT_DIR`: Укажите входной каталог с текстовыми файлами для синтеза
- `--output-file OUTPUT_FILE`: Укажите выходной аудиофайл (по умолчанию: output.wav)
- `--output-dir OUTPUT_DIR`: Укажите выходной каталог для синтезированных аудиофайлов (по умолчанию: output)
- `--log-level INFO` : Укажите уровень журнала, вы можете отключить его, используя значение NONE. (по умолчанию: INFO)

#### Примеры

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "silero-tts"
version = "0.0.4"
version = "0.0.5"
authors = [
{ name="daswer123", email="[email protected]" },
]
Expand Down
Binary file modified requirements.txt
Binary file not shown.
14 changes: 13 additions & 1 deletion silero_tts/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os
import sys
from loguru import logger
from tqdm import tqdm
from silero_tts.silero_tts import SileroTTS
Expand All @@ -18,8 +19,19 @@ def main():
parser.add_argument('--input-dir', type=str, help='Input directory with text files to synthesize')
parser.add_argument('--output-file', type=str, default='output.wav', help='Output audio file (default: output.wav)')
parser.add_argument('--output-dir', type=str, default='output', help='Output directory for synthesized audio files (default: output)')
parser.add_argument('--log-level', type=str, default='INFO', help='Logging level (default: INFO)')
args = parser.parse_args()

# Set logging level
logger.remove()


if (args.log_level.upper() == "NONE"):
logger.remove()
# logger = None
else:
logger.add(sys.stderr, level=args.log_level.upper())

try:
models_file = os.path.join(os.path.dirname(__file__), 'latest_silero_models.yml')

Expand Down Expand Up @@ -102,5 +114,5 @@ def main():
except Exception as e:
logger.exception(f"An error occurred: {str(e)}")

if __name__== '__main__':
if __name__ == '__main__':
main()
Binary file added silero_tts/silero_models/v4_ru_ru.pt
Binary file not shown.
55 changes: 36 additions & 19 deletions silero_tts/silero_tts.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import os
import re
import timeit
from urllib import request
import torch
import sys
import wave
import yaml
import requests
from loguru import logger
from datetime import datetime, timedelta
from number2text.number2text import NumberToText

from silero_tts.lang_data import is_cyrillic, is_latin, lang_data
from silero_tts.transliterate import reverse_transliterate, transliterate

class SileroTTS:
def __init__(self, model_id: str, language: str, speaker: str = None, sample_rate: int = 48000, device: str = 'cpu',
put_accent=True, put_yo=True, num_threads=6):
put_accent=True, put_yo=True, num_threads=6):
self.model_id = model_id
self.language = language
self.sample_rate = sample_rate
Expand All @@ -26,7 +22,7 @@ def __init__(self, model_id: str, language: str, speaker: str = None, sample_rat
self.num_threads = num_threads

self.models_config = self.load_models_config()
self.tts_model, _= self.init_model()
self.tts_model = self.init_model()

if speaker is None:
self.speaker = self.tts_model.speakers[0]
Expand Down Expand Up @@ -152,36 +148,57 @@ def init_model(self):
logger.info("Initializing model")
t0 = timeit.default_timer()

# https://github.com/snakers4/silero-models/issues/183
torch._C._jit_set_profiling_mode(False) # Fixes initial delay

if not torch.cuda.is_available() and self.device == "auto":
self.device = 'cpu'
if torch.cuda.is_available() and self.device == "auto" or self.device == "cuda":
if torch.cuda.is_available() and (self.device == "auto" or self.device == "cuda"):
torch_dev = torch.device("cuda", 0)
gpus_count = torch.cuda.device_count() # 1
gpus_count = torch.cuda.device_count()
logger.info(f"Using {gpus_count} GPU(s)...")
else:
torch_dev = torch.device(self.device)
torch.set_num_threads(self.num_threads)
tts_model, _= torch.hub.load(repo_or_dir='snakers4/silero-models',
model='silero_tts',
language=self.language,
speaker=self.model_id)
logger.info(f"Setup takes {timeit.default_timer() - t0:.2f} seconds")

# Create silero_models directory
silero_models_dir = os.path.join(os.path.dirname(__file__), 'silero_models')
if not os.path.exists(silero_models_dir):
os.makedirs(silero_models_dir)

# Get package URL from models config
package_url = self.models_config['tts_models'][self.language][self.model_id]['latest']['package']

# Define model file path
model_file_name = f"{self.model_id}_{self.language}.pt"
model_file_path = os.path.join(silero_models_dir, model_file_name)

# Download model file if not exists
if not os.path.exists(model_file_path):
logger.info(f"Downloading model from {package_url} to {model_file_path}")
response = requests.get(package_url, stream=True)
if response.status_code == 200:
with open(model_file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.success(f"Model downloaded successfully.")
else:
logger.error(f"Failed to download model file. Status code: {response.status_code}")
raise Exception(f"Failed to download model file. Status code: {response.status_code}")

# Load model from local file
logger.info("Loading model")
t1 = timeit.default_timer()
tts_model.to(torch_dev) # gpu or cpu
from torch.package import PackageImporter
model = PackageImporter(model_file_path).load_pickle("tts_models", "model")
model.to(torch_dev)
logger.info(f"Model to device takes {timeit.default_timer() - t1:.2f} seconds")

if torch.cuda.is_available() and self.device == "auto" or self.device == "cuda":
if torch.cuda.is_available() and (self.device == "auto" or self.device == "cuda"):
logger.info("Synchronizing CUDA")
t2 = timeit.default_timer()
torch.cuda.synchronize()
logger.info(f"Cuda Synch takes {timeit.default_timer() - t2:.2f} seconds")
logger.success("Model is loaded")
return tts_model, _

return model

def find_char_positions(self, string: str, char: str) -> list:
pos = [] # list to store positions for each 'char' in 'string'
Expand Down

0 comments on commit 2ea742c

Please sign in to comment.