diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index 2f9399add8..79aef5cb14 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -101,7 +101,9 @@ def test_xtts_streaming(): from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts - speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")] + speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav") + speaker_wav.append(speaker_wav_2) model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1") config = XttsConfig() config.load_json(os.path.join(model_path, "config.json")) @@ -131,20 +133,21 @@ def test_xtts_v2(): """XTTS is too big to run on github actions. We need to test it locally""" output_path = os.path.join(get_tests_output_path(), "output.wav") speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav") use_gpu = torch.cuda.is_available() if use_gpu: run_cli( "yes | " f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 " f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True ' - f'--speaker_wav "{speaker_wav}" --language_idx "en"' + f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"' ) else: run_cli( "yes | " f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 " f'--text "This is an example." --out_path "{output_path}" --progress_bar False ' - f'--speaker_wav "{speaker_wav}" --language_idx "en"' + f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"' ) @@ -153,7 +156,7 @@ def test_xtts_v2_streaming(): from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts - speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")] model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2") config = XttsConfig() config.load_json(os.path.join(model_path, "config.json"))