diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 0f89cb4ce..a5a31f4b9 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -15,5 +15,6 @@ jobs: with: python-version: 3.x - run: pip install mkdocs-material + - run: pip install mkdocs-glightbox - run: pip install "mkdocstrings[python]" - run: mkdocs gh-deploy --force \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f76f7177a..dca7c789e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ asyncio = "*" nest_asyncio = "*" einops = "*" google-generativeai = "*" -torch = "*" +torch = "2.1.0" langchain-experimental = "*" playwright = "*" duckduckgo-search = "*" diff --git a/requirements.txt b/requirements.txt index 82e519af0..e1148c30c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,4 +70,4 @@ rich mkdocs mkdocs-material -mkdocs-glightbox +mkdocs-glightbox \ No newline at end of file diff --git a/swarms/models/openai_chat.py b/swarms/models/openai_chat.py index 6ca964a2c..3933d8a79 100644 --- a/swarms/models/openai_chat.py +++ b/swarms/models/openai_chat.py @@ -214,7 +214,7 @@ def is_lc_serializable(cls) -> bool: # Check for classes that derive from this class (as some of them # may assume openai_api_key is a str) # openai_api_key: Optional[str] = Field(default=None, alias="api_key") - openai_api_key = "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" + openai_api_key: Optional[str] = Field(default=None, alias="api_key") """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_api_base: Optional[str] = Field(default=None, alias="base_url") """Base URL path for API requests, leave blank if not using a proxy or service diff --git a/swarms/models/whisperx.py b/swarms/models/whisperx.py index 1731daa19..102ae7d79 100644 --- a/swarms/models/whisperx.py +++ b/swarms/models/whisperx.py @@ -1 +1,125 @@ -"""An ultra fast speech to text model.""" +# speech to text tool + +import os +import subprocess + +import whisperx +from pydub import AudioSegment +from pytube import YouTube + + +class WhisperX: + def __init__( + self, + video_url, + audio_format="mp3", + device="cuda", + batch_size=16, + compute_type="float16", + hf_api_key=None, + ): + """ + # Example usage + video_url = "url" + speech_to_text = SpeechToText(video_url) + transcription = speech_to_text.transcribe_youtube_video() + print(transcription) + + """ + self.video_url = video_url + self.audio_format = audio_format + self.device = device + self.batch_size = batch_size + self.compute_type = compute_type + self.hf_api_key = hf_api_key + + def install(self): + subprocess.run(["pip", "install", "whisperx"]) + subprocess.run(["pip", "install", "pytube"]) + subprocess.run(["pip", "install", "pydub"]) + + def download_youtube_video(self): + audio_file = f"video.{self.audio_format}" + + # Download video πŸ“₯ + yt = YouTube(self.video_url) + yt_stream = yt.streams.filter(only_audio=True).first() + yt_stream.download(filename="video.mp4") + + # Convert video to audio 🎧 + video = AudioSegment.from_file("video.mp4", format="mp4") + video.export(audio_file, format=self.audio_format) + os.remove("video.mp4") + + return audio_file + + def transcribe_youtube_video(self): + audio_file = self.download_youtube_video() + + device = "cuda" + batch_size = 16 + compute_type = "float16" + + # 1. Transcribe with original Whisper (batched) πŸ—£οΈ + model = whisperx.load_model("large-v2", device, compute_type=compute_type) + audio = whisperx.load_audio(audio_file) + result = model.transcribe(audio, batch_size=batch_size) + + # 2. Align Whisper output πŸ” + model_a, metadata = whisperx.load_align_model( + language_code=result["language"], device=device + ) + result = whisperx.align( + result["segments"], + model_a, + metadata, + audio, + device, + return_char_alignments=False, + ) + + # 3. Assign speaker labels 🏷️ + diarize_model = whisperx.DiarizationPipeline( + use_auth_token=self.hf_api_key, device=device + ) + diarize_model(audio_file) + + try: + segments = result["segments"] + transcription = " ".join(segment["text"] for segment in segments) + return transcription + except KeyError: + print("The key 'segments' is not found in the result.") + + def transcribe(self, audio_file): + model = whisperx.load_model("large-v2", self.device, self.compute_type) + audio = whisperx.load_audio(audio_file) + result = model.transcribe(audio, batch_size=self.batch_size) + + # 2. Align Whisper output πŸ” + model_a, metadata = whisperx.load_align_model( + language_code=result["language"], device=self.device + ) + + result = whisperx.align( + result["segments"], + model_a, + metadata, + audio, + self.device, + return_char_alignments=False, + ) + + # 3. Assign speaker labels 🏷️ + diarize_model = whisperx.DiarizationPipeline( + use_auth_token=self.hf_api_key, device=self.device + ) + + diarize_model(audio_file) + + try: + segments = result["segments"] + transcription = " ".join(segment["text"] for segment in segments) + return transcription + except KeyError: + print("The key 'segments' is not found in the result.") diff --git a/swarms/tools/stt.py b/swarms/tools/stt.py deleted file mode 100644 index cfe3e6567..000000000 --- a/swarms/tools/stt.py +++ /dev/null @@ -1,125 +0,0 @@ -# speech to text tool - -import os -import subprocess - -import whisperx -from pydub import AudioSegment -from pytube import YouTube - - -class SpeechToText: - def __init__( - self, - video_url, - audio_format="mp3", - device="cuda", - batch_size=16, - compute_type="float16", - hf_api_key=None, - ): - """ - # Example usage - video_url = "url" - speech_to_text = SpeechToText(video_url) - transcription = speech_to_text.transcribe_youtube_video() - print(transcription) - - """ - self.video_url = video_url - self.audio_format = audio_format - self.device = device - self.batch_size = batch_size - self.compute_type = compute_type - self.hf_api_key = hf_api_key - - def install(self): - subprocess.run(["pip", "install", "whisperx"]) - subprocess.run(["pip", "install", "pytube"]) - subprocess.run(["pip", "install", "pydub"]) - - def download_youtube_video(self): - audio_file = f"video.{self.audio_format}" - - # Download video πŸ“₯ - yt = YouTube(self.video_url) - yt_stream = yt.streams.filter(only_audio=True).first() - yt_stream.download(filename="video.mp4") - - # Convert video to audio 🎧 - video = AudioSegment.from_file("video.mp4", format="mp4") - video.export(audio_file, format=self.audio_format) - os.remove("video.mp4") - - return audio_file - - def transcribe_youtube_video(self): - audio_file = self.download_youtube_video() - - device = "cuda" - batch_size = 16 - compute_type = "float16" - - # 1. Transcribe with original Whisper (batched) πŸ—£οΈ - model = whisperx.load_model("large-v2", device, compute_type=compute_type) - audio = whisperx.load_audio(audio_file) - result = model.transcribe(audio, batch_size=batch_size) - - # 2. Align Whisper output πŸ” - model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=device - ) - result = whisperx.align( - result["segments"], - model_a, - metadata, - audio, - device, - return_char_alignments=False, - ) - - # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=device - ) - diarize_model(audio_file) - - try: - segments = result["segments"] - transcription = " ".join(segment["text"] for segment in segments) - return transcription - except KeyError: - print("The key 'segments' is not found in the result.") - - def transcribe(self, audio_file): - model = whisperx.load_model("large-v2", self.device, self.compute_type) - audio = whisperx.load_audio(audio_file) - result = model.transcribe(audio, batch_size=self.batch_size) - - # 2. Align Whisper output πŸ” - model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=self.device - ) - - result = whisperx.align( - result["segments"], - model_a, - metadata, - audio, - self.device, - return_char_alignments=False, - ) - - # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=self.device - ) - - diarize_model(audio_file) - - try: - segments = result["segments"] - transcription = " ".join(segment["text"] for segment in segments) - return transcription - except KeyError: - print("The key 'segments' is not found in the result.") diff --git a/tests/models/kosmos2.py b/tests/models/kosmos2.py new file mode 100644 index 000000000..2ff010927 --- /dev/null +++ b/tests/models/kosmos2.py @@ -0,0 +1,365 @@ +import pytest +import os +from PIL import Image +from swarms.models.kosmos2 import Kosmos2, Detections + + +# Fixture for a sample image +@pytest.fixture +def sample_image(): + image = Image.new("RGB", (224, 224)) + return image + + +# Fixture for initializing Kosmos2 +@pytest.fixture +def kosmos2(): + return Kosmos2.initialize() + + +# Test Kosmos2 initialization +def test_kosmos2_initialization(kosmos2): + assert kosmos2 is not None + + +# Test Kosmos2 with a sample image +def test_kosmos2_with_sample_image(kosmos2, sample_image): + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 0 + ) + + +# Mocked extract_entities function for testing +def mock_extract_entities(text): + return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))] + + +# Mocked process_entities_to_detections function for testing +def mock_process_entities_to_detections(entities, image): + return Detections( + xyxy=[(10, 20, 30, 40), (50, 60, 70, 80)], + class_id=[0, 0], + confidence=[1.0, 1.0], + ) + + +# Test Kosmos2 with mocked entity extraction and detection +def test_kosmos2_with_mocked_extraction_and_detection( + kosmos2, sample_image, monkeypatch +): + monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + monkeypatch.setattr( + kosmos2, "process_entities_to_detections", mock_process_entities_to_detections + ) + + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 2 + ) + + +# Test Kosmos2 with empty entity extraction +def test_kosmos2_with_empty_extraction(kosmos2, sample_image, monkeypatch): + monkeypatch.setattr(kosmos2, "extract_entities", lambda x: []) + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 0 + ) + + +# Test Kosmos2 with invalid image path +def test_kosmos2_with_invalid_image_path(kosmos2): + with pytest.raises(Exception): + kosmos2(img="invalid_image_path.jpg") + + +# Additional tests can be added for various scenarios and edge cases + + +# Test Kosmos2 with a larger image +def test_kosmos2_with_large_image(kosmos2): + large_image = Image.new("RGB", (1024, 768)) + detections = kosmos2(img=large_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 0 + ) + + +# Test Kosmos2 with different image formats +def test_kosmos2_with_different_image_formats(kosmos2, tmp_path): + # Create a temporary directory + temp_dir = tmp_path / "images" + temp_dir.mkdir() + + # Create sample images in different formats + image_formats = ["jpeg", "png", "gif", "bmp"] + for format in image_formats: + image_path = temp_dir / f"sample_image.{format}" + Image.new("RGB", (224, 224)).save(image_path) + + # Test Kosmos2 with each image format + for format in image_formats: + image_path = temp_dir / f"sample_image.{format}" + detections = kosmos2(img=image_path) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 0 + ) + + +# Test Kosmos2 with a non-existent model +def test_kosmos2_with_non_existent_model(kosmos2): + with pytest.raises(Exception): + kosmos2.model = None + kosmos2(img="sample_image.jpg") + + +# Test Kosmos2 with a non-existent processor +def test_kosmos2_with_non_existent_processor(kosmos2): + with pytest.raises(Exception): + kosmos2.processor = None + kosmos2(img="sample_image.jpg") + + +# Test Kosmos2 with missing image +def test_kosmos2_with_missing_image(kosmos2): + with pytest.raises(Exception): + kosmos2(img="non_existent_image.jpg") + + +# ... (previous tests) + + +# Test Kosmos2 with a non-existent model and processor +def test_kosmos2_with_non_existent_model_and_processor(kosmos2): + with pytest.raises(Exception): + kosmos2.model = None + kosmos2.processor = None + kosmos2(img="sample_image.jpg") + + +# Test Kosmos2 with a corrupted image +def test_kosmos2_with_corrupted_image(kosmos2, tmp_path): + # Create a temporary directory + temp_dir = tmp_path / "images" + temp_dir.mkdir() + + # Create a corrupted image + corrupted_image_path = temp_dir / "corrupted_image.jpg" + with open(corrupted_image_path, "wb") as f: + f.write(b"corrupted data") + + with pytest.raises(Exception): + kosmos2(img=corrupted_image_path) + + +# Test Kosmos2 with a large batch size +def test_kosmos2_with_large_batch_size(kosmos2, sample_image): + kosmos2.batch_size = 32 + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 0 + ) + + +# Test Kosmos2 with an invalid compute type +def test_kosmos2_with_invalid_compute_type(kosmos2, sample_image): + kosmos2.compute_type = "invalid_compute_type" + with pytest.raises(Exception): + kosmos2(img=sample_image) + + +# Test Kosmos2 with a valid HF API key +def test_kosmos2_with_valid_hf_api_key(kosmos2, sample_image): + kosmos2.hf_api_key = "valid_api_key" + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 2 + ) + + +# Test Kosmos2 with an invalid HF API key +def test_kosmos2_with_invalid_hf_api_key(kosmos2, sample_image): + kosmos2.hf_api_key = "invalid_api_key" + with pytest.raises(Exception): + kosmos2(img=sample_image) + + +# Test Kosmos2 with a very long generated text +def test_kosmos2_with_long_generated_text(kosmos2, sample_image, monkeypatch): + def mock_generate_text(*args, **kwargs): + return "A" * 10000 + + monkeypatch.setattr(kosmos2.model, "generate", mock_generate_text) + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 0 + ) + + +# Test Kosmos2 with entities containing special characters +def test_kosmos2_with_entities_containing_special_characters( + kosmos2, sample_image, monkeypatch +): + def mock_extract_entities(text): + return [("entity1 with special characters (ΓΌ, ΓΆ, etc.)", (0.1, 0.2, 0.3, 0.4))] + + monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 1 + ) + + +# Test Kosmos2 with image containing multiple objects +def test_kosmos2_with_image_containing_multiple_objects( + kosmos2, sample_image, monkeypatch +): + def mock_extract_entities(text): + return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))] + + monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 2 + ) + + +# Test Kosmos2 with image containing no objects +def test_kosmos2_with_image_containing_no_objects(kosmos2, sample_image, monkeypatch): + def mock_extract_entities(text): + return [] + + monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 0 + ) + + +# Test Kosmos2 with a valid YouTube video URL +def test_kosmos2_with_valid_youtube_video_url(kosmos2): + youtube_video_url = "https://www.youtube.com/watch?v=VIDEO_ID" + detections = kosmos2(video_url=youtube_video_url) + assert isinstance(detections, Detections) + assert ( + len(detections.xyxy) + == len(detections.class_id) + == len(detections.confidence) + == 2 + ) + + +# Test Kosmos2 with an invalid YouTube video URL +def test_kosmos2_with_invalid_youtube_video_url(kosmos2): + invalid_youtube_video_url = "https://www.youtube.com/invalid_video" + with pytest.raises(Exception): + kosmos2(video_url=invalid_youtube_video_url) + + +# Test Kosmos2 with no YouTube video URL provided +def test_kosmos2_with_no_youtube_video_url(kosmos2): + with pytest.raises(Exception): + kosmos2(video_url=None) + + +# Test Kosmos2 installation +def test_kosmos2_installation(): + kosmos2 = Kosmos2() + kosmos2.install() + assert os.path.exists("video.mp4") + assert os.path.exists("video.mp3") + os.remove("video.mp4") + os.remove("video.mp3") + + +# Test Kosmos2 termination +def test_kosmos2_termination(kosmos2): + kosmos2.terminate() + assert kosmos2.process is None + + +# Test Kosmos2 start_process method +def test_kosmos2_start_process(kosmos2): + kosmos2.start_process() + assert kosmos2.process is not None + + +# Test Kosmos2 preprocess_code method +def test_kosmos2_preprocess_code(kosmos2): + code = "print('Hello, World!')" + preprocessed_code = kosmos2.preprocess_code(code) + assert isinstance(preprocessed_code, str) + assert "end_of_execution" in preprocessed_code + + +# Test Kosmos2 run method with debug mode +def test_kosmos2_run_with_debug_mode(kosmos2, sample_image): + kosmos2.debug_mode = True + detections = kosmos2(img=sample_image) + assert isinstance(detections, Detections) + + +# Test Kosmos2 handle_stream_output method +def test_kosmos2_handle_stream_output(kosmos2): + stream_output = "Sample output" + kosmos2.handle_stream_output(stream_output, is_error=False) + + +# Test Kosmos2 run method with invalid image path +def test_kosmos2_run_with_invalid_image_path(kosmos2): + with pytest.raises(Exception): + kosmos2.run(img="invalid_image_path.jpg") + + +# Test Kosmos2 run method with invalid video URL +def test_kosmos2_run_with_invalid_video_url(kosmos2): + with pytest.raises(Exception): + kosmos2.run(video_url="invalid_video_url") + + +# ... (more tests) diff --git a/tests/models/whisperx.py b/tests/models/whisperx.py new file mode 100644 index 000000000..17a288571 --- /dev/null +++ b/tests/models/whisperx.py @@ -0,0 +1,206 @@ +import os +import subprocess +import tempfile +from unittest.mock import Mock, patch + +import pytest +import whisperx +from pydub import AudioSegment +from pytube import YouTube +from your_module import SpeechToText + +from swarms.models.whisperx import SpeechToText + + +# Fixture to create a temporary directory for testing +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tempdir: + yield tempdir + + +# Mock subprocess.run to prevent actual installation during tests +@patch.object(subprocess, "run") +def test_speech_to_text_install(mock_run): + stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + stt.install() + mock_run.assert_called_with(["pip", "install", "whisperx"]) + + +# Mock pytube.YouTube and pytube.Streams for download tests +@patch("pytube.YouTube") +@patch.object(YouTube, "streams") +def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_dir): + # Mock YouTube and streams + video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" + mock_stream = mock_streams().filter().first() + mock_stream.download.return_value = os.path.join(temp_dir, "video.mp4") + mock_youtube.return_value = mock_youtube + mock_youtube.streams = mock_streams + + stt = SpeechToText(video_url) + audio_file = stt.download_youtube_video() + + assert os.path.exists(audio_file) + assert audio_file.endswith(".mp3") + + +# Mock whisperx.load_model and whisperx.load_audio for transcribe tests +@patch("whisperx.load_model") +@patch("whisperx.load_audio") +@patch("whisperx.load_align_model") +@patch("whisperx.align") +@patch.object(whisperx.DiarizationPipeline, "__call__") +def test_speech_to_text_transcribe_youtube_video( + mock_diarization, + mock_align, + mock_align_model, + mock_load_audio, + mock_load_model, + temp_dir, +): + # Mock whisperx functions + mock_load_model.return_value = mock_load_model + mock_load_model.transcribe.return_value = { + "language": "en", + "segments": [{"text": "Hello, World!"}], + } + + mock_load_audio.return_value = "audio_path" + mock_align_model.return_value = (mock_align_model, "metadata") + mock_align.return_value = {"segments": [{"text": "Hello, World!"}]} + + # Mock diarization pipeline + mock_diarization.return_value = None + + video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video" + stt = SpeechToText(video_url) + transcription = stt.transcribe_youtube_video() + + assert transcription == "Hello, World!" + + +# More tests for different scenarios and edge cases can be added here. + + +# Test transcribe method with provided audio file +def test_speech_to_text_transcribe_audio_file(temp_dir): + # Create a temporary audio file + audio_file = os.path.join(temp_dir, "test_audio.mp3") + AudioSegment.silent(duration=500).export(audio_file, format="mp3") + + stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + transcription = stt.transcribe(audio_file) + + assert transcription == "" + + +# Test transcribe method when Whisperx fails +@patch("whisperx.load_model") +@patch("whisperx.load_audio") +def test_speech_to_text_transcribe_whisperx_failure( + mock_load_audio, mock_load_model, temp_dir +): + # Mock whisperx functions to raise an exception + mock_load_model.side_effect = Exception("Whisperx failed") + mock_load_audio.return_value = "audio_path" + + stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + transcription = stt.transcribe("audio_path") + + assert transcription == "Whisperx failed" + + +# Test transcribe method with missing 'segments' key in Whisperx output +@patch("whisperx.load_model") +@patch("whisperx.load_audio") +@patch("whisperx.load_align_model") +@patch("whisperx.align") +@patch.object(whisperx.DiarizationPipeline, "__call__") +def test_speech_to_text_transcribe_missing_segments( + mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model +): + # Mock whisperx functions to return incomplete output + mock_load_model.return_value = mock_load_model + mock_load_model.transcribe.return_value = {"language": "en"} + + mock_load_audio.return_value = "audio_path" + mock_align_model.return_value = (mock_align_model, "metadata") + mock_align.return_value = {} + + # Mock diarization pipeline + mock_diarization.return_value = None + + stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + transcription = stt.transcribe("audio_path") + + assert transcription == "" + + +# Test transcribe method with Whisperx align failure +@patch("whisperx.load_model") +@patch("whisperx.load_audio") +@patch("whisperx.load_align_model") +@patch("whisperx.align") +@patch.object(whisperx.DiarizationPipeline, "__call__") +def test_speech_to_text_transcribe_align_failure( + mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model +): + # Mock whisperx functions to raise an exception during align + mock_load_model.return_value = mock_load_model + mock_load_model.transcribe.return_value = { + "language": "en", + "segments": [{"text": "Hello, World!"}], + } + + mock_load_audio.return_value = "audio_path" + mock_align_model.return_value = (mock_align_model, "metadata") + mock_align.side_effect = Exception("Align failed") + + # Mock diarization pipeline + mock_diarization.return_value = None + + stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + transcription = stt.transcribe("audio_path") + + assert transcription == "Align failed" + + +# Test transcribe_youtube_video when Whisperx diarization fails +@patch("pytube.YouTube") +@patch.object(YouTube, "streams") +@patch("whisperx.DiarizationPipeline") +@patch("whisperx.load_audio") +@patch("whisperx.load_align_model") +@patch("whisperx.align") +def test_speech_to_text_transcribe_diarization_failure( + mock_align, + mock_align_model, + mock_load_audio, + mock_diarization, + mock_streams, + mock_youtube, + temp_dir, +): + # Mock YouTube and streams + video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" + mock_stream = mock_streams().filter().first() + mock_stream.download.return_value = os.path.join(temp_dir, "video.mp4") + mock_youtube.return_value = mock_youtube + mock_youtube.streams = mock_streams + + # Mock whisperx functions + mock_load_audio.return_value = "audio_path" + mock_align_model.return_value = (mock_align_model, "metadata") + mock_align.return_value = {"segments": [{"text": "Hello, World!"}]} + + # Mock diarization pipeline to raise an exception + mock_diarization.side_effect = Exception("Diarization failed") + + stt = SpeechToText(video_url) + transcription = stt.transcribe_youtube_video() + + assert transcription == "Diarization failed" + + +# Add more tests for other scenarios and edge cases as needed. diff --git a/tests/tools/base.py b/tests/tools/base.py new file mode 100644 index 000000000..4f7e2b4b3 --- /dev/null +++ b/tests/tools/base.py @@ -0,0 +1,802 @@ +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +from swarms.tools.tool import BaseTool, Runnable, StructuredTool, Tool, tool + +# Define test data +test_input = {"key1": "value1", "key2": "value2"} +expected_output = "expected_output_value" + +# Test with global variables +global_var = "global" + + +# Basic tests for BaseTool +def test_base_tool_init(): + # Test BaseTool initialization + tool = BaseTool() + assert isinstance(tool, BaseTool) + + +def test_base_tool_invoke(): + # Test BaseTool invoke method + tool = BaseTool() + result = tool.invoke(test_input) + assert result == expected_output + + +# Basic tests for Tool +def test_tool_init(): + # Test Tool initialization + tool = Tool() + assert isinstance(tool, Tool) + + +def test_tool_invoke(): + # Test Tool invoke method + tool = Tool() + result = tool.invoke(test_input) + assert result == expected_output + + +# Basic tests for StructuredTool +def test_structured_tool_init(): + # Test StructuredTool initialization + tool = StructuredTool() + assert isinstance(tool, StructuredTool) + + +def test_structured_tool_invoke(): + # Test StructuredTool invoke method + tool = StructuredTool() + result = tool.invoke(test_input) + assert result == expected_output + + +# Test additional functionality and edge cases as needed + + +def test_tool_creation(): + tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") + assert tool.name == "test_tool" + assert tool.func is not None + assert tool.description == "Test tool" + + +def test_tool_ainvoke(): + tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") + result = tool.ainvoke("input_data") + assert result == "input_data" + + +def test_tool_ainvoke_with_coroutine(): + async def async_function(input_data): + return input_data + + tool = Tool(name="test_tool", coroutine=async_function, description="Test tool") + result = tool.ainvoke("input_data") + assert result == "input_data" + + +def test_tool_args(): + def sample_function(input_data): + return input_data + + tool = Tool(name="test_tool", func=sample_function, description="Test tool") + assert tool.args == {"tool_input": {"type": "string"}} + + +# Basic tests for StructuredTool class + + +def test_structured_tool_creation(): + class SampleArgsSchema: + pass + + tool = StructuredTool( + name="test_tool", + func=lambda x: x, + description="Test tool", + args_schema=SampleArgsSchema, + ) + assert tool.name == "test_tool" + assert tool.func is not None + assert tool.description == "Test tool" + assert tool.args_schema == SampleArgsSchema + + +def test_structured_tool_ainvoke(): + class SampleArgsSchema: + pass + + tool = StructuredTool( + name="test_tool", + func=lambda x: x, + description="Test tool", + args_schema=SampleArgsSchema, + ) + result = tool.ainvoke({"tool_input": "input_data"}) + assert result == "input_data" + + +def test_structured_tool_ainvoke_with_coroutine(): + class SampleArgsSchema: + pass + + async def async_function(input_data): + return input_data + + tool = StructuredTool( + name="test_tool", + coroutine=async_function, + description="Test tool", + args_schema=SampleArgsSchema, + ) + result = tool.ainvoke({"tool_input": "input_data"}) + assert result == "input_data" + + +def test_structured_tool_args(): + class SampleArgsSchema: + pass + + def sample_function(input_data): + return input_data + + tool = StructuredTool( + name="test_tool", + func=sample_function, + description="Test tool", + args_schema=SampleArgsSchema, + ) + assert tool.args == {"tool_input": {"type": "string"}} + + +# Additional tests for exception handling + + +def test_tool_ainvoke_exception(): + tool = Tool(name="test_tool", func=None, description="Test tool") + with pytest.raises(NotImplementedError): + tool.ainvoke("input_data") + + +def test_tool_ainvoke_with_coroutine_exception(): + tool = Tool(name="test_tool", coroutine=None, description="Test tool") + with pytest.raises(NotImplementedError): + tool.ainvoke("input_data") + + +def test_structured_tool_ainvoke_exception(): + class SampleArgsSchema: + pass + + tool = StructuredTool( + name="test_tool", + func=None, + description="Test tool", + args_schema=SampleArgsSchema, + ) + with pytest.raises(NotImplementedError): + tool.ainvoke({"tool_input": "input_data"}) + + +def test_structured_tool_ainvoke_with_coroutine_exception(): + class SampleArgsSchema: + pass + + tool = StructuredTool( + name="test_tool", + coroutine=None, + description="Test tool", + args_schema=SampleArgsSchema, + ) + with pytest.raises(NotImplementedError): + tool.ainvoke({"tool_input": "input_data"}) + + +def test_tool_description_not_provided(): + tool = Tool(name="test_tool", func=lambda x: x) + assert tool.name == "test_tool" + assert tool.func is not None + assert tool.description == "" + + +def test_tool_invoke_with_callbacks(): + def sample_function(input_data, callbacks=None): + if callbacks: + callbacks.on_start() + callbacks.on_finish() + return input_data + + tool = Tool(name="test_tool", func=sample_function) + callbacks = MagicMock() + result = tool.invoke("input_data", callbacks=callbacks) + assert result == "input_data" + callbacks.on_start.assert_called_once() + callbacks.on_finish.assert_called_once() + + +def test_tool_invoke_with_new_argument(): + def sample_function(input_data, callbacks=None): + return input_data + + tool = Tool(name="test_tool", func=sample_function) + result = tool.invoke("input_data", callbacks=None) + assert result == "input_data" + + +def test_tool_ainvoke_with_new_argument(): + async def async_function(input_data, callbacks=None): + return input_data + + tool = Tool(name="test_tool", coroutine=async_function) + result = tool.ainvoke("input_data", callbacks=None) + assert result == "input_data" + + +def test_tool_description_from_docstring(): + def sample_function(input_data): + """Sample function docstring""" + return input_data + + tool = Tool(name="test_tool", func=sample_function) + assert tool.description == "Sample function docstring" + + +def test_tool_ainvoke_with_exceptions(): + async def async_function(input_data): + raise ValueError("Test exception") + + tool = Tool(name="test_tool", coroutine=async_function) + with pytest.raises(ValueError): + tool.ainvoke("input_data") + + +# Additional tests for StructuredTool class + + +def test_structured_tool_infer_schema_false(): + def sample_function(input_data): + return input_data + + tool = StructuredTool( + name="test_tool", + func=sample_function, + args_schema=None, + infer_schema=False, + ) + assert tool.args_schema is None + + +def test_structured_tool_ainvoke_with_callbacks(): + class SampleArgsSchema: + pass + + def sample_function(input_data, callbacks=None): + if callbacks: + callbacks.on_start() + callbacks.on_finish() + return input_data + + tool = StructuredTool( + name="test_tool", + func=sample_function, + args_schema=SampleArgsSchema, + ) + callbacks = MagicMock() + result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks) + assert result == "input_data" + callbacks.on_start.assert_called_once() + callbacks.on_finish.assert_called_once() + + +def test_structured_tool_description_not_provided(): + class SampleArgsSchema: + pass + + tool = StructuredTool( + name="test_tool", + func=lambda x: x, + args_schema=SampleArgsSchema, + ) + assert tool.name == "test_tool" + assert tool.func is not None + assert tool.description == "" + + +def test_structured_tool_args_schema(): + class SampleArgsSchema: + pass + + def sample_function(input_data): + return input_data + + tool = StructuredTool( + name="test_tool", + func=sample_function, + args_schema=SampleArgsSchema, + ) + assert tool.args_schema == SampleArgsSchema + + +def test_structured_tool_args_schema_inference(): + def sample_function(input_data): + return input_data + + tool = StructuredTool( + name="test_tool", + func=sample_function, + args_schema=None, + infer_schema=True, + ) + assert tool.args_schema is not None + + +def test_structured_tool_ainvoke_with_new_argument(): + class SampleArgsSchema: + pass + + def sample_function(input_data, callbacks=None): + return input_data + + tool = StructuredTool( + name="test_tool", + func=sample_function, + args_schema=SampleArgsSchema, + ) + result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None) + assert result == "input_data" + + +def test_structured_tool_ainvoke_with_exceptions(): + class SampleArgsSchema: + pass + + async def async_function(input_data): + raise ValueError("Test exception") + + tool = StructuredTool( + name="test_tool", + coroutine=async_function, + args_schema=SampleArgsSchema, + ) + with pytest.raises(ValueError): + tool.ainvoke({"tool_input": "input_data"}) + + +# Test additional functionality and edge cases +def test_tool_with_fixture(some_fixture): + # Test Tool with a fixture + tool = Tool() + result = tool.invoke(test_input) + assert result == expected_output + + +def test_structured_tool_with_fixture(some_fixture): + # Test StructuredTool with a fixture + tool = StructuredTool() + result = tool.invoke(test_input) + assert result == expected_output + + +def test_base_tool_verbose_logging(caplog): + # Test verbose logging in BaseTool + tool = BaseTool(verbose=True) + result = tool.invoke(test_input) + assert result == expected_output + assert "Verbose logging" in caplog.text + + +def test_tool_exception_handling(): + # Test exception handling in Tool + tool = Tool() + with pytest.raises(Exception): + tool.invoke(test_input, raise_exception=True) + + +def test_structured_tool_async_invoke(): + # Test asynchronous invoke in StructuredTool + tool = StructuredTool() + result = tool.ainvoke(test_input) + assert result == expected_output + + +def test_tool_async_invoke_with_fixture(some_fixture): + # Test asynchronous invoke with a fixture in Tool + tool = Tool() + result = tool.ainvoke(test_input) + assert result == expected_output + + +# Add more tests for specific functionalities and edge cases as needed +# Import necessary libraries and modules + + +# Example of a mock function to be used in testing +def mock_function(arg: str) -> str: + """A simple mock function for testing.""" + return f"Processed {arg}" + + +# Example of a Runnable class for testing +class MockRunnable(Runnable): + # Define necessary methods and properties + pass + + +# Fixture for creating a mock function +@pytest.fixture +def mock_func(): + return mock_function + + +# Fixture for creating a Runnable instance +@pytest.fixture +def mock_runnable(): + return MockRunnable() + + +# Basic functionality tests +def test_tool_with_callable(mock_func): + # Test creating a tool with a simple callable + tool_instance = tool(mock_func) + assert isinstance(tool_instance, BaseTool) + + +def test_tool_with_runnable(mock_runnable): + # Test creating a tool with a Runnable instance + tool_instance = tool(mock_runnable) + assert isinstance(tool_instance, BaseTool) + + +# ... more basic functionality tests ... + + +# Argument handling tests +def test_tool_with_invalid_argument(): + # Test passing an invalid argument type + with pytest.raises(ValueError): + tool(123) # Using an integer instead of a string/callable/Runnable + + +def test_tool_with_multiple_arguments(mock_func): + # Test passing multiple valid arguments + tool_instance = tool("mock", mock_func) + assert isinstance(tool_instance, BaseTool) + + +# ... more argument handling tests ... + + +# Schema inference and application tests +class TestSchema(BaseModel): + arg: str + + +def test_tool_with_args_schema(mock_func): + # Test passing a custom args_schema + tool_instance = tool(mock_func, args_schema=TestSchema) + assert tool_instance.args_schema == TestSchema + + +# ... more schema tests ... + + +# Exception handling tests +def test_tool_function_without_docstring(): + # Test that a ValueError is raised if the function lacks a docstring + def no_doc_func(arg: str) -> str: + return arg + + with pytest.raises(ValueError): + tool(no_doc_func) + + +# ... more exception tests ... + + +# Decorator behavior tests +@pytest.mark.asyncio +async def test_async_tool_function(): + # Test an async function with the tool decorator + @tool + async def async_func(arg: str) -> str: + return arg + + # Add async specific assertions here + + +# ... more decorator tests ... + + +class MockSchema(BaseModel): + """Mock schema for testing args_schema.""" + + arg: str + + +# Test suite starts here +class TestTool: + # Basic Functionality Tests + def test_tool_with_valid_callable_creates_base_tool(self, mock_func): + result = tool(mock_func) + assert isinstance(result, BaseTool) + + def test_tool_returns_correct_function_name(self, mock_func): + result = tool(mock_func) + assert result.func.__name__ == "mock_function" + + # Argument Handling Tests + def test_tool_with_string_and_runnable(self, mock_runnable): + result = tool("mock_runnable", mock_runnable) + assert isinstance(result, BaseTool) + + def test_tool_raises_error_with_invalid_arguments(self): + with pytest.raises(ValueError): + tool(123) + + # Schema Inference and Application Tests + def test_tool_with_args_schema(self, mock_func): + result = tool(mock_func, args_schema=MockSchema) + assert result.args_schema == MockSchema + + def test_tool_with_infer_schema_true(self, mock_func): + tool(mock_func, infer_schema=True) + # Assertions related to schema inference + + # Return Direct Feature Tests + def test_tool_with_return_direct_true(self, mock_func): + tool(mock_func, return_direct=True) + # Assertions for return_direct behavior + + # Error Handling Tests + def test_tool_raises_error_without_docstring(self): + def no_doc_func(arg: str) -> str: + return arg + + with pytest.raises(ValueError): + tool(no_doc_func) + + def test_tool_raises_error_runnable_without_object_schema(self, mock_runnable): + with pytest.raises(ValueError): + tool(mock_runnable) + + # Decorator Behavior Tests + @pytest.mark.asyncio + async def test_async_tool_function(self): + @tool + async def async_func(arg: str) -> str: + return arg + + # Assertions for async behavior + + # Integration with StructuredTool and Tool Classes + def test_integration_with_structured_tool(self, mock_func): + result = tool(mock_func) + assert isinstance(result, StructuredTool) + + # Concurrency and Async Handling Tests + def test_concurrency_in_tool(self, mock_func): + # Test related to concurrency + pass + + # Mocking and Isolation Tests + def test_mocking_external_dependencies(self, mocker): + # Use mocker to mock external dependencies + pass + + def test_tool_with_different_return_types(self): + @tool + def return_int(arg: str) -> int: + return int(arg) + + result = return_int("123") + assert isinstance(result, int) + assert result == 123 + + @tool + def return_bool(arg: str) -> bool: + return arg.lower() in ["true", "yes"] + + result = return_bool("true") + assert isinstance(result, bool) + assert result is True + + # Test with multiple arguments + def test_tool_with_multiple_args(self): + @tool + def concat_strings(a: str, b: str) -> str: + return a + b + + result = concat_strings("Hello", "World") + assert result == "HelloWorld" + + # Test handling of optional arguments + def test_tool_with_optional_args(self): + @tool + def greet(name: str, greeting: str = "Hello") -> str: + return f"{greeting} {name}" + + assert greet("Alice") == "Hello Alice" + assert greet("Alice", greeting="Hi") == "Hi Alice" + + # Test with variadic arguments + def test_tool_with_variadic_args(self): + @tool + def sum_numbers(*numbers: int) -> int: + return sum(numbers) + + assert sum_numbers(1, 2, 3) == 6 + assert sum_numbers(10, 20) == 30 + + # Test with keyword arguments + def test_tool_with_kwargs(self): + @tool + def build_query(**kwargs) -> str: + return "&".join(f"{k}={v}" for k, v in kwargs.items()) + + assert build_query(a=1, b=2) == "a=1&b=2" + assert build_query(foo="bar") == "foo=bar" + + # Test with mixed types of arguments + def test_tool_with_mixed_args(self): + @tool + def mixed_args(a: int, b: str, *args, **kwargs) -> str: + return f"{a}{b}{len(args)}{'-'.join(kwargs.values())}" + + assert mixed_args(1, "b", "c", "d", x="y", z="w") == "1b2y-w" + + # Test error handling with incorrect types + def test_tool_error_with_incorrect_types(self): + @tool + def add_numbers(a: int, b: int) -> int: + return a + b + + with pytest.raises(TypeError): + add_numbers("1", "2") + + # Test with nested tools + def test_nested_tools(self): + @tool + def inner_tool(arg: str) -> str: + return f"Inner {arg}" + + @tool + def outer_tool(arg: str) -> str: + return f"Outer {inner_tool(arg)}" + + assert outer_tool("Test") == "Outer Inner Test" + + def test_tool_with_global_variable(self): + @tool + def access_global(arg: str) -> str: + return f"{global_var} {arg}" + + assert access_global("Var") == "global Var" + + # Test with environment variables + def test_tool_with_env_variables(self, monkeypatch): + monkeypatch.setenv("TEST_VAR", "Environment") + + @tool + def access_env_variable(arg: str) -> str: + import os + + return f"{os.environ['TEST_VAR']} {arg}" + + assert access_env_variable("Var") == "Environment Var" + + # ... [Previous test cases] ... + + # Test with complex data structures + def test_tool_with_complex_data_structures(self): + @tool + def process_data(data: dict) -> list: + return [data[key] for key in sorted(data.keys())] + + result = process_data({"b": 2, "a": 1}) + assert result == [1, 2] + + # Test handling exceptions within the tool function + def test_tool_handling_internal_exceptions(self): + @tool + def function_that_raises(arg: str): + if arg == "error": + raise ValueError("Error occurred") + return arg + + with pytest.raises(ValueError): + function_that_raises("error") + assert function_that_raises("ok") == "ok" + + # Test with functions returning None + def test_tool_with_none_return(self): + @tool + def return_none(arg: str): + return None + + assert return_none("anything") is None + + # Test with lambda functions + def test_tool_with_lambda(self): + tool_lambda = tool(lambda x: x * 2) + assert tool_lambda(3) == 6 + + # Test with class methods + def test_tool_with_class_method(self): + class MyClass: + @tool + def method(self, arg: str) -> str: + return f"Method {arg}" + + obj = MyClass() + assert obj.method("test") == "Method test" + + # Test tool function with inheritance + def test_tool_with_inheritance(self): + class Parent: + @tool + def parent_method(self, arg: str) -> str: + return f"Parent {arg}" + + class Child(Parent): + @tool + def child_method(self, arg: str) -> str: + return f"Child {arg}" + + child_obj = Child() + assert child_obj.parent_method("test") == "Parent test" + assert child_obj.child_method("test") == "Child test" + + # Test with decorators stacking + def test_tool_with_multiple_decorators(self): + def another_decorator(func): + def wrapper(*args, **kwargs): + return f"Decorated {func(*args, **kwargs)}" + + return wrapper + + @tool + @another_decorator + def decorated_function(arg: str): + return f"Function {arg}" + + assert decorated_function("test") == "Decorated Function test" + + # Test tool function when used in a multi-threaded environment + def test_tool_in_multithreaded_environment(self): + import threading + + @tool + def threaded_function(arg: int) -> int: + return arg * 2 + + results = [] + + def thread_target(): + results.append(threaded_function(5)) + + threads = [threading.Thread(target=thread_target) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert results == [10] * 10 + + # Test with recursive functions + def test_tool_with_recursive_function(self): + @tool + def recursive_function(n: int) -> int: + if n == 0: + return 0 + else: + return n + recursive_function(n - 1) + + assert recursive_function(5) == 15 + + +# Additional tests can be added here to cover more scenarios diff --git a/tests/utils/subprocess_code_interpreter.py b/tests/utils/subprocess_code_interpreter.py new file mode 100644 index 000000000..601f8a09e --- /dev/null +++ b/tests/utils/subprocess_code_interpreter.py @@ -0,0 +1,246 @@ +import time +import threading +import pytest +import subprocess +from swarms.utils.code_interpreter import BaseCodeInterpreter, SubprocessCodeInterpreter + + +@pytest.fixture +def subprocess_code_interpreter(): + interpreter = SubprocessCodeInterpreter() + interpreter.start_cmd = "python -c" + yield interpreter + interpreter.terminate() + + +def test_base_code_interpreter_init(): + interpreter = BaseCodeInterpreter() + assert isinstance(interpreter, BaseCodeInterpreter) + + +def test_base_code_interpreter_run_not_implemented(): + interpreter = BaseCodeInterpreter() + with pytest.raises(NotImplementedError): + interpreter.run("code") + + +def test_base_code_interpreter_terminate_not_implemented(): + interpreter = BaseCodeInterpreter() + with pytest.raises(NotImplementedError): + interpreter.terminate() + + +def test_subprocess_code_interpreter_init(subprocess_code_interpreter): + assert isinstance(subprocess_code_interpreter, SubprocessCodeInterpreter) + + +def test_subprocess_code_interpreter_start_process(subprocess_code_interpreter): + subprocess_code_interpreter.start_process() + assert subprocess_code_interpreter.process is not None + + +def test_subprocess_code_interpreter_terminate(subprocess_code_interpreter): + subprocess_code_interpreter.start_process() + subprocess_code_interpreter.terminate() + assert subprocess_code_interpreter.process.poll() is not None + + +def test_subprocess_code_interpreter_run_success(subprocess_code_interpreter): + code = 'print("Hello, World!")' + result = list(subprocess_code_interpreter.run(code)) + assert any("Hello, World!" in output.get("output", "") for output in result) + + +def test_subprocess_code_interpreter_run_with_error(subprocess_code_interpreter): + code = 'print("Hello, World")\nraise ValueError("Error!")' + result = list(subprocess_code_interpreter.run(code)) + assert any("Error!" in output.get("output", "") for output in result) + + +def test_subprocess_code_interpreter_run_with_keyboard_interrupt( + subprocess_code_interpreter, +): + code = 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise KeyboardInterrupt' + result = list(subprocess_code_interpreter.run(code)) + assert any("KeyboardInterrupt" in output.get("output", "") for output in result) + + +def test_subprocess_code_interpreter_run_max_retries( + subprocess_code_interpreter, monkeypatch +): + def mock_subprocess_popen(*args, **kwargs): + raise subprocess.CalledProcessError(1, "mocked_cmd") + + monkeypatch.setattr(subprocess, "Popen", mock_subprocess_popen) + + code = 'print("Hello, World!")' + result = list(subprocess_code_interpreter.run(code)) + assert any( + "Maximum retries reached. Could not execute code." in output.get("output", "") + for output in result + ) + + +def test_subprocess_code_interpreter_run_retry_on_error( + subprocess_code_interpreter, monkeypatch +): + def mock_subprocess_popen(*args, **kwargs): + nonlocal popen_count + if popen_count == 0: + popen_count += 1 + raise subprocess.CalledProcessError(1, "mocked_cmd") + else: + return subprocess.Popen( + "echo 'Hello, World!'", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + monkeypatch.setattr(subprocess, "Popen", mock_subprocess_popen) + popen_count = 0 + + code = 'print("Hello, World!")' + result = list(subprocess_code_interpreter.run(code)) + assert any("Hello, World!" in output.get("output", "") for output in result) + + +# Add more tests to cover other aspects of the code and edge cases as needed + +# Import statements and fixtures from the previous code block + + +def test_subprocess_code_interpreter_line_postprocessor(subprocess_code_interpreter): + line = "This is a test line" + processed_line = subprocess_code_interpreter.line_postprocessor(line) + assert processed_line == line # No processing, should remain the same + + +def test_subprocess_code_interpreter_preprocess_code(subprocess_code_interpreter): + code = 'print("Hello, World!")' + preprocessed_code = subprocess_code_interpreter.preprocess_code(code) + assert preprocessed_code == code # No preprocessing, should remain the same + + +def test_subprocess_code_interpreter_detect_active_line(subprocess_code_interpreter): + line = "Active line: 5" + active_line = subprocess_code_interpreter.detect_active_line(line) + assert active_line == 5 + + +def test_subprocess_code_interpreter_detect_end_of_execution( + subprocess_code_interpreter, +): + line = "Execution completed." + end_of_execution = subprocess_code_interpreter.detect_end_of_execution(line) + assert end_of_execution is True + + +def test_subprocess_code_interpreter_run_debug_mode( + subprocess_code_interpreter, capsys +): + subprocess_code_interpreter.debug_mode = True + code = 'print("Hello, World!")' + result = list(subprocess_code_interpreter.run(code)) + captured = capsys.readouterr() + assert "Running code:\n" in captured.out + assert "Received output line:\n" in captured.out + + +def test_subprocess_code_interpreter_run_no_debug_mode( + subprocess_code_interpreter, capsys +): + subprocess_code_interpreter.debug_mode = False + code = 'print("Hello, World!")' + result = list(subprocess_code_interpreter.run(code)) + captured = capsys.readouterr() + assert "Running code:\n" not in captured.out + assert "Received output line:\n" not in captured.out + + +def test_subprocess_code_interpreter_run_empty_output_queue( + subprocess_code_interpreter, +): + code = 'print("Hello, World!")' + result = list(subprocess_code_interpreter.run(code)) + assert not any("active_line" in output for output in result) + + +def test_subprocess_code_interpreter_handle_stream_output_stdout( + subprocess_code_interpreter, +): + line = "This is a test line" + subprocess_code_interpreter.handle_stream_output(threading.current_thread(), False) + subprocess_code_interpreter.process.stdout.write(line + "\n") + subprocess_code_interpreter.process.stdout.flush() + time.sleep(0.1) + output = subprocess_code_interpreter.output_queue.get() + assert output["output"] == line + + +def test_subprocess_code_interpreter_handle_stream_output_stderr( + subprocess_code_interpreter, +): + line = "This is an error line" + subprocess_code_interpreter.handle_stream_output(threading.current_thread(), True) + subprocess_code_interpreter.process.stderr.write(line + "\n") + subprocess_code_interpreter.process.stderr.flush() + time.sleep(0.1) + output = subprocess_code_interpreter.output_queue.get() + assert output["output"] == line + + +def test_subprocess_code_interpreter_run_with_preprocess_code( + subprocess_code_interpreter, capsys +): + code = 'print("Hello, World!")' + subprocess_code_interpreter.preprocess_code = ( + lambda x: x.upper() + ) # Modify code in preprocess_code + result = list(subprocess_code_interpreter.run(code)) + assert any("Hello, World!" in output.get("output", "") for output in result) + + +def test_subprocess_code_interpreter_run_with_exception( + subprocess_code_interpreter, capsys +): + code = 'print("Hello, World!")' + subprocess_code_interpreter.start_cmd = ( + "nonexistent_command" # Force an exception during subprocess creation + ) + result = list(subprocess_code_interpreter.run(code)) + assert any( + "Maximum retries reached" in output.get("output", "") for output in result + ) + + +def test_subprocess_code_interpreter_run_with_active_line( + subprocess_code_interpreter, capsys +): + code = "a = 5\nprint(a)" # Contains an active line + result = list(subprocess_code_interpreter.run(code)) + assert any(output.get("active_line") == 5 for output in result) + + +def test_subprocess_code_interpreter_run_with_end_of_execution( + subprocess_code_interpreter, capsys +): + code = 'print("Hello, World!")' # Simple code without active line marker + result = list(subprocess_code_interpreter.run(code)) + assert any(output.get("active_line") is None for output in result) + + +def test_subprocess_code_interpreter_run_with_multiple_lines( + subprocess_code_interpreter, capsys +): + code = "a = 5\nb = 10\nprint(a + b)" + result = list(subprocess_code_interpreter.run(code)) + assert any("15" in output.get("output", "") for output in result) + + +def test_subprocess_code_interpreter_run_with_unicode_characters( + subprocess_code_interpreter, capsys +): + code = 'print("γ“γ‚“γ«γ‘γ―γ€δΈ–η•Œ")' # Contains unicode characters + result = list(subprocess_code_interpreter.run(code)) + assert any("γ“γ‚“γ«γ‘γ―γ€δΈ–η•Œ" in output.get("output", "") for output in result)