Skip to content

Commit

Permalink
Add multiples references on xtts inference tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson authored and erogol committed Nov 6, 2023
1 parent 1b6f8d0 commit f444f29
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/zoo_tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def test_xtts_streaming():
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
speaker_wav.append(speaker_wav_2)
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
Expand Down Expand Up @@ -131,20 +133,21 @@ def test_xtts_v2():
"""XTTS is too big to run on github actions. We need to test it locally"""
output_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
use_gpu = torch.cuda.is_available()
if use_gpu:
run_cli(
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
)
else:
run_cli(
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
)


Expand All @@ -153,7 +156,7 @@ def test_xtts_v2_streaming():
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
Expand Down

0 comments on commit f444f29

Please sign in to comment.