diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py index c37ac71b..cc4a6f80 100644 --- a/tests/test_bgm_separation.py +++ b/tests/test_bgm_separation.py @@ -6,9 +6,14 @@ import gradio as gr import pytest +import torch import os +@pytest.mark.skipif( + not is_cuda_available(), + reason="Skipping because the test only works on GPU" +) @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ @@ -26,7 +31,10 @@ def test_bgm_separation_pipeline( test_transcribe(whisper_type, vad_filter, bgm_separation, diarization) -@pytest.mark.skip(reason="Too heavy to run in actions with all of other tests") +@pytest.mark.skipif( + not is_cuda_available(), + reason="Skipping because the test only works on GPU" +) @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ diff --git a/tests/test_config.py b/tests/test_config.py index fd22ec71..0f60aa58 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,7 @@ from modules.utils.paths import * import os +import torch TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav") @@ -11,3 +12,6 @@ TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt") TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt") + +def is_cuda_available(): + return torch.cuda.is_available()