diff --git a/ROADMAP.md b/ROADMAP.md
new file mode 100644
index 0000000..f7de9c1
--- /dev/null
+++ b/ROADMAP.md
@@ -0,0 +1,55 @@
+# Updated 7-Day SAA Implementation Roadmap
+
+## Day 1: SAAsWorkers Integration
+
+- [x] Implement SAAsWorkers class in a new file `src/workers.py`
+- [x] Add asynchronous execution capabilities for parallel processing
+- [x] Integrate SAAsWorkers with the existing Orchestrator class
+
+## Day 2: Enhance Main Assistant and Prompts
+
+- [ ] Improve the `_generate_main_prompt` method in `src/orchestrator.py` for better task breakdown
+- [ ] Implement logic for the MAIN_ASSISTANT to handle tasks without subtask decomposition
+- [ ] Enhance prompt writing capabilities for SUB_ASSISTANT tasks
+
+## Day 3: Implement Parallel Research Use Case
+
+- [ ] Develop a parallel research system in `src/use_cases/research.py`
+- [ ] Integrate the research use case with SAAsWorkers
+- [ ] Add necessary tools for web scraping and data processing
+
+## Day 4: CLI Enhancements and Basic UI
+
+- [ ] Extend the CLI in `src/main.py` to support new SAAsWorkers functionality
+- [ ] Implement a basic Streamlit UI for interaction (chosen over SvelteKit due to time constraints)
+- [ ] Create commands for the research use case
+
+## Day 5: Testing and Error Handling
+
+- [ ] Update existing tests in `tests/test_orchestrator.py` for new functionality
+- [ ] Add new tests for SAAsWorkers and the research use case
+- [ ] Enhance error handling and logging throughout the project
+
+## Day 6: Documentation and Examples
+
+- [ ] Update the README.md with new features and usage instructions
+- [ ] Create example scripts for the research use case
+- [ ] Document the SAAsWorkers implementation and integration
+
+## Day 7: Optimization and Final Testing
+
+- [ ] Optimize parallel execution in SAAsWorkers
+- [ ] Conduct end-to-end testing of the entire system
+- [ ] Address any remaining bugs or issues
+- [ ] Prepare for deployment (if applicable)
+
+# Backlog (Future Development)
+
+1. Implement additional use cases (e.g., content creation, autocomplete)
+2. Develop a plugin system for easy integration of new tools
+3. Enhance the configuration management system
+4. Implement a full-featured FastAPI-based API
+5. Develop strategies for handling larger workloads and scaling
+6. Integrate additional Phidata tools and features
+7. Implement long-term memory and knowledge base systems
+8. Create a more advanced UI with data visualization capabilities
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 8ad51d6..9be122e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,12 +4,13 @@
 # Testing
 pytest==8.2.2
 pytest-mock==3.14.0
+pytest-asyncio==0.23.7
 
 # Type checking
-mypy==1.7.1
+mypy==1.7.1
 
 # Debugging
-ipdb==0.13.13
+ipdb==0.13.13
 
 # Security
 bandit==1.7.5
@@ -27,4 +28,4 @@ python-dotenv==1.0.1 # Already in requirements.txt, but included here for compl
 typer[all]==0.12.3 # Already in requirements.txt, but included here for completeness
 # Code Formating
 black==24.4.2
-isort==5.12.0
\ No newline at end of file
+isort==5.12.0
diff --git a/requirements.txt b/requirements.txt
index 8d082a2..2fa74a6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,7 +25,7 @@ numpy==2.0.0
 python-multipart==0.0.9
 
 # Async support
-anyio==4.4.0
+anyio==4.4.0 # asyncio is part of the Python 3 standard library and must not be pip-installed; anyio stays pinned for async support
 
 # CLI enhancements
 click==8.1.7
@@ -34,8 +34,8 @@ shellingham==1.5.4
 
 # Time handling
 python-dateutil==2.9.0.post0
-# pgvector
-# pypdf
-# psycopg2-binary
-# sqlalchemy
-# fastapi
\ No newline at end of file
+# pgvector
+# pypdf
+# psycopg2-binary
+# sqlalchemy
+# fastapi
diff --git a/src/__init__.py b/src/__init__.py
index 0744385..ed72a5b 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -1,14 +1,16 @@
-from .assistants import get_full_response, main_assistant, refiner_assistant, sub_assistant
+from .assistants import create_assistant, get_full_response
 from .config import settings
 from .orchestrator import Orchestrator, Task, TaskExchange
+from .workers import PlanResponse, SAAsWorkers, WorkerTask
 
 __all__ = [
     "get_full_response",
-    "main_assistant",
-    "refiner_assistant",
-    "sub_assistant",
+    "create_assistant",
     "settings",
     "Orchestrator",
     "Task",
     "TaskExchange",
+    "SAAsWorkers",
+    "WorkerTask",
+    "PlanResponse",
 ]
diff --git a/src/assistants.py b/src/assistants.py
index 8a75c5e..df97594 100644
--- a/src/assistants.py
+++ b/src/assistants.py
@@ -74,6 +74,7 @@ def create_assistant(name: str, model: str):
                 read_file,
                 list_files,
             ],
+            debug_mode=True,
         )
     except Exception as e:
         logger.error(f"Error creating assistant {name} with model {model}: {str(e)}")
diff --git a/src/config.py b/src/config.py
index 7da2b77..e129df3 100644
--- a/src/config.py
+++ b/src/config.py
@@ -38,17 +38,20 @@ def tavily_api_key(self) -> str:
     OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")
 
     # Assistant settings
-    MAIN_ASSISTANT: str = "claude-3-5-sonnet-20240620"
-    SUB_ASSISTANT: str = "gpt-3.5-turbo"
+    MAIN_ASSISTANT: str = "claude-3-sonnet-20240229"
+    SUB_ASSISTANT: str = "claude-3-haiku-20240307"
     REFINER_ASSISTANT: str = "gemini-1.5-pro-preview-0409"
 
     # Fallback models
-    FALLBACK_MODEL_1: str = "claude-3-sonnet-20240229"
+    FALLBACK_MODEL_1: str = "gpt-3.5-turbo"
     FALLBACK_MODEL_2: str = "gpt-3.5-turbo"
 
     # Tools
     TAVILY_API_KEY: Optional[str] = os.getenv("TAVILY_API_KEY")
 
+    # New setting for SAAsWorkers
+    NUM_WORKERS: int = 3
+
     class Config:
         env_file = ".env"
         extra = "ignore"  # This will ignore any extra fields in the environment
diff --git a/src/main.py b/src/main.py
index 9ce90e4..0c529df 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,3 +1,5 @@
+import asyncio
+
 import typer
 from rich import print as rprint
 
@@ -19,7 +21,7 @@ def run_workflow(
     try:
         rprint("[bold]Starting SAA Orchestrator[/bold]")
         orchestrator = Orchestrator()
-        result = orchestrator.run_workflow(full_objective)
+        result = asyncio.run(orchestrator.run_workflow(full_objective))
         rprint("\n[bold green]Workflow completed![/bold green]")
         rprint("\n[bold]Final Output:[/bold]")
diff --git a/src/orchestrator.py b/src/orchestrator.py index 023353d..ccef919 100644 --- a/src/orchestrator.py +++ b/src/orchestrator.py @@ -1,13 +1,13 @@ -import json import os from typing import Any, Dict, List, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field -from .assistants import create_assistant, get_full_response +from .assistants import create_assistant from .config import settings from .utils.exceptions import AssistantError, WorkflowError from .utils.logging import setup_logging +from .workers import PlanResponse, SAAsWorkers logger = setup_logging() @@ -19,10 +19,11 @@ class TaskExchange(BaseModel): class Task(BaseModel): task: str + prompt: str result: str def to_dict(self) -> Dict[str, Any]: - return {"task": str(self.task), "result": str(self.result)} + return {"task": str(self.task), "prompt": str(self.prompt), "result": str(self.result)} class State(BaseModel): @@ -39,14 +40,17 @@ def to_dict(self) -> Dict[str, Any]: class Orchestrator(BaseModel): state: State = State() output_dir: str = Field(default_factory=lambda: os.path.join(os.getcwd(), "output")) + workers: SAAsWorkers = Field(default_factory=lambda: SAAsWorkers(settings.NUM_WORKERS)) + + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__(self, **data): super().__init__(**data) os.makedirs(self.output_dir, exist_ok=True) - def run_workflow(self, objective: str) -> str: + async def run_workflow(self, objective: str) -> str: """ - Executes the workflow to accomplish the given objective. + Executes the workflow to accomplish the given objective using SAAsWorkers. Args: objective (str): The main task or goal to be accomplished. @@ -54,31 +58,52 @@ def run_workflow(self, objective: str) -> str: Returns: str: The final refined output of the workflow. 
""" - logger.info(f"Starting workflow with objective: {objective}") self.state.task_exchanges.append(TaskExchange(role="user", content=objective)) - task_counter = 1 - try: - while True: - logger.info(f"Starting task {task_counter}") - main_prompt = self._generate_main_prompt(objective) - main_response = self._get_assistant_response("main", main_prompt) - - if main_response.startswith("ALL DONE:"): - logger.info("Workflow completed") - break - - sub_task_prompt = self._generate_sub_task_prompt(main_response) - sub_response = self._get_assistant_response("sub", sub_task_prompt) - self.state.tasks.append(Task(task=main_response, result=sub_response)) - task_counter += 1 - - refined_output = self._get_refined_output(objective) - self._save_exchange_log(objective, refined_output) - logger.info("Exchange log saved") + try: + main_assistant = create_assistant("MainAssistant", settings.MAIN_ASSISTANT) + refiner_assistant = create_assistant("RefinerAssistant", settings.REFINER_ASSISTANT) + + # Plan tasks + plan_result: PlanResponse = await self.workers.plan_tasks(objective, main_assistant) + + if plan_result.objective_completion: + # Single-task scenario + final_output = plan_result.explanation + self.state.task_exchanges.append( + TaskExchange(role="main_assistant", content=final_output) + ) + else: + # Multi-task scenario + tasks = plan_result.tasks if plan_result.tasks else [] + self.state.task_exchanges.append( + TaskExchange(role="main_assistant", content="\n".join([t.task for t in tasks])) + ) + + # Process tasks + results = await self.workers.process_tasks(tasks) + for result in results: + self.state.tasks.append( + Task(task=result.task, prompt=result.prompt, result=result.result) + ) + self.state.task_exchanges.append( + TaskExchange(role="sub_assistant", content=result.result) + ) + + # Summarize results + final_output = await self.workers.summarize_results( + objective, results, refiner_assistant + ) + self.state.task_exchanges.append( + TaskExchange(role="refiner_assistant", content=final_output) + ) + + self._save_exchange_log(objective, final_output) + logger.info("Workflow completed and exchange log saved") + + return final_output - return refined_output except AssistantError as e: logger.error(f"Assistant error: {str(e)}") raise WorkflowError(f"Workflow failed due to assistant error: {str(e)}") @@ -86,48 +111,6 @@ def run_workflow(self, objective: str) -> str: logger.exception("Unexpected error in workflow execution") raise WorkflowError(f"Unexpected error in workflow execution: {str(e)}") - def _generate_main_prompt(self, objective: str) -> str: - return ( - f"Objective: {objective}\n\n" - f"Current progress:\n{json.dumps(self.state.to_dict(), indent=2)}\n\n" - "Break down this objective into the next specific sub-task, or if the objective is fully achieved, " - "start your response with 'ALL DONE:' followed by the final output." - ) - - def _generate_sub_task_prompt(self, main_response: str) -> str: - return ( - f"Previous tasks: {json.dumps([task.to_dict() for task in self.state.tasks], indent=2)}\n\n" - f"Current task: {main_response}\n\n" - "Execute this task and provide the result. Use the provided functions to create, read, or list files as needed. " - f"All file operations should be relative to the '{self.output_dir}' directory." 
- ) - - def _get_assistant_response(self, assistant_type: str, prompt: str) -> str: - try: - assistant_model = getattr(settings, f"{assistant_type.upper()}_ASSISTANT") - assistant = create_assistant(f"{assistant_type.capitalize()}Assistant", assistant_model) - response = get_full_response(assistant, prompt) - logger.info(f"{assistant_type.capitalize()} assistant response received") - self.state.task_exchanges.append( - TaskExchange(role=f"{assistant_type}_assistant", content=response) - ) - return response - except Exception as e: - logger.error(f"Error getting response from {assistant_type} assistant: {str(e)}") - raise AssistantError( - f"Error getting response from {assistant_type} assistant: {str(e)}" - ) - - def _get_refined_output(self, objective: str) -> str: - refiner_prompt = ( - f"Original objective: {objective}\n\n" - f"Task breakdown and results: {json.dumps([task.to_dict() for task in self.state.tasks], indent=2)}\n\n" - "Please refine these results into a coherent final output, summarizing the project structure created. " - f"You can use the provided functions to list and read files if needed. All files are in the '{self.output_dir}' directory. " - "Provide your response as a string, not a list or dictionary." - ) - return self._get_assistant_response("refiner", refiner_prompt) - def _save_exchange_log(self, objective: str, final_output: str): """ Saves the workflow exchange log to a markdown file. diff --git a/src/utils/exceptions.py b/src/utils/exceptions.py index e75fb64..4634e7f 100644 --- a/src/utils/exceptions.py +++ b/src/utils/exceptions.py @@ -16,3 +16,7 @@ class ConfigurationError(SAAOrchestratorError): class PluginError(SAAOrchestratorError): """Raised when there's an error with a plugin""" + + +class WorkerError(Exception): + """Base exception class for SAA Workers""" diff --git a/src/workers.py b/src/workers.py new file mode 100644 index 0000000..d1663e0 --- /dev/null +++ b/src/workers.py @@ -0,0 +1,128 @@ +import asyncio +import json +from typing import List, Optional + +from phi.assistant import Assistant +from pydantic import BaseModel, Field + +from src.assistants import create_assistant, get_full_response +from src.config import settings +from src.utils.exceptions import WorkerError +from src.utils.logging import setup_logging + +logger = setup_logging() + + +class WorkerTask(BaseModel): + task: str = Field(..., description="Brief description of the task") + prompt: str = Field(..., description="Detailed prompt for the worker to accomplish the task") + result: Optional[str] = Field(None, description="Result of the task execution") + + +class PlanResponse(BaseModel): + objective_completion: bool = Field( + ..., description="Whether the objective can be completed without subtasks" + ) + explanation: str = Field( + ..., description="Explanation or direct response if objective_completion is True" + ) + tasks: Optional[List[WorkerTask]] = Field( + None, description="List of tasks if objective_completion is False" + ) + + +class SAAsWorkers: + def __init__(self, num_workers: int = 3): + self.num_workers = num_workers + self.workers = [ + create_assistant(f"Worker{i}", settings.SUB_ASSISTANT) for i in range(num_workers) + ] + + async def execute_task(self, worker: Assistant, task: WorkerTask) -> str: + try: + return await asyncio.to_thread(get_full_response, worker, task.prompt) + except Exception as e: + logger.error(f"Error executing task: {str(e)}") + raise WorkerError(f"Error executing task: {str(e)}") + + async def process_tasks(self, tasks: List[WorkerTask]) 
-> List[WorkerTask]: + worker_tasks = [] + for task, worker in zip(tasks, self.workers): + worker_tasks.append(self.execute_task(worker, task)) + + results = await asyncio.gather(*worker_tasks, return_exceptions=True) + + processed_tasks = [] + for task, result in zip(tasks, results): + if isinstance(result, Exception): + logger.error(f"Task failed: {task.task}. Error: {str(result)}") + task.result = f"Error: {str(result)}" + else: + task.result = result + processed_tasks.append(task) + + return processed_tasks + + @staticmethod + async def plan_tasks(objective: str, main_assistant: Assistant) -> PlanResponse: + plan_prompt = f""" + Analyze the following objective and determine if it requires subtask decomposition: + + Objective: {objective} + + Respond with a JSON object that follows this structure: + + {{ + "objective_completion": boolean, + "explanation": string, + "tasks": [ + {{ + "task": string, + "prompt": string + }}, + ... + ] + }} + + If the objective can be accomplished without subtask decomposition: + - Set "objective_completion" to true + - Provide a concise solution or response to the objective in the "explanation" field + - Leave the "tasks" array empty + + If the objective requires subtask decomposition: + - Set "objective_completion" to false + - Provide a brief explanation in the "explanation" field + - Break down the objective into {settings.NUM_WORKERS} subtasks in the "tasks" array + - For each subtask, include a "task" field with a brief description and a "prompt" field with detailed instructions + + Remember, you are a skilled prompt engineer. Create prompts that are clear, specific, and actionable. + """ + + planner = Assistant( + name="TaskPlanner", + llm=main_assistant.llm, + description="You are a task planner that analyzes objectives and breaks them down into subtasks if necessary.", + ) + + response = await asyncio.to_thread(get_full_response, planner, plan_prompt) + + try: + plan_dict = json.loads(response) + return PlanResponse(**plan_dict) + except json.JSONDecodeError: + logger.error(f"Failed to parse JSON response: {response}") + raise WorkerError("Failed to parse plan response as JSON") + except ValueError as e: + logger.error(f"Invalid plan response structure: {str(e)}") + raise WorkerError(f"Invalid plan response structure: {str(e)}") + + @staticmethod + async def summarize_results( + objective: str, results: List[WorkerTask], refiner_assistant: Assistant + ) -> str: + summary_prompt = f"Objective: {objective}\n\nTask results:\n" + for task in results: + summary_prompt += f"Task: {task.task}\nResult: {task.result}\n\n" + summary_prompt += "Please summarize these results into a coherent final output that addresses the original objective." 
+ + return await asyncio.to_thread(get_full_response, refiner_assistant, summary_prompt) diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py index 21d246c..87e751e 100644 --- a/tests/test_orchestrator.py +++ b/tests/test_orchestrator.py @@ -1,10 +1,11 @@ import os -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from src.orchestrator import Orchestrator, Task, TaskExchange from src.utils.exceptions import WorkflowError +from src.workers import PlanResponse, WorkerTask @pytest.fixture @@ -27,55 +28,80 @@ def test_task_exchange(): def test_task(): - task = Task(task="Test task", result="Test result") + task = Task(task="Test task", prompt="Test prompt", result="Test result") assert task.task == "Test task" assert task.result == "Test result" - assert task.to_dict() == {"task": "Test task", "result": "Test result"} + assert task.to_dict() == {"task": "Test task", "prompt": "Test prompt", "result": "Test result"} +@pytest.mark.asyncio @patch("src.orchestrator.create_assistant") -@patch("src.orchestrator.get_full_response") -def test_run_workflow(mock_get_full_response, mock_create_assistant, orchestrator): - mock_assistant = MagicMock() - mock_create_assistant.return_value = mock_assistant - - mock_get_full_response.side_effect = [ - "First sub-task", - "Result of first sub-task", - "ALL DONE: Final output", - "Refined output", - ] +async def test_run_workflow_single_task(mock_create_assistant, orchestrator): + mock_main_assistant = AsyncMock() + mock_refiner_assistant = AsyncMock() + mock_create_assistant.side_effect = [mock_main_assistant, mock_refiner_assistant] + + mock_plan_response = PlanResponse( + objective_completion=True, explanation="This is a simple task.", tasks=None + ) + orchestrator.workers.plan_tasks = AsyncMock(return_value=mock_plan_response) - result = orchestrator.run_workflow("Test objective") + result = await orchestrator.run_workflow("Test objective") - assert isinstance(result, str) - assert "Refined output" in result + assert result == "This is a simple task." 
+ assert len(orchestrator.state.task_exchanges) == 2 + assert orchestrator.state.task_exchanges[0].role == "user" + assert orchestrator.state.task_exchanges[1].role == "main_assistant" + +@pytest.mark.asyncio +@patch("src.orchestrator.create_assistant") +async def test_run_workflow_multiple_tasks(mock_create_assistant, orchestrator): + mock_main_assistant = AsyncMock() + mock_refiner_assistant = AsyncMock() + mock_create_assistant.side_effect = [mock_main_assistant, mock_refiner_assistant] + + mock_plan_response = PlanResponse( + objective_completion=False, + explanation="This task requires multiple steps.", + tasks=[ + WorkerTask(task="Subtask 1", prompt="Do subtask 1"), + WorkerTask(task="Subtask 2", prompt="Do subtask 2"), + ], + ) + orchestrator.workers.plan_tasks = AsyncMock(return_value=mock_plan_response) + orchestrator.workers.process_tasks = AsyncMock( + return_value=[ + WorkerTask(task="Subtask 1", prompt="Do subtask 1", result="Result 1"), + WorkerTask(task="Subtask 2", prompt="Do subtask 2", result="Result 2"), + ] + ) + orchestrator.workers.summarize_results = AsyncMock(return_value="Final summary") + + result = await orchestrator.run_workflow("Test objective") + + assert result == "Final summary" assert len(orchestrator.state.task_exchanges) == 5 - assert len(orchestrator.state.tasks) == 1 - - expected_roles = [ - "user", - "main_assistant", - "sub_assistant", - "main_assistant", - "refiner_assistant", - ] - actual_roles = [exchange.role for exchange in orchestrator.state.task_exchanges] - assert actual_roles == expected_roles + assert orchestrator.state.task_exchanges[0].role == "user" + assert orchestrator.state.task_exchanges[1].role == "main_assistant" + assert orchestrator.state.task_exchanges[2].role == "sub_assistant" + assert orchestrator.state.task_exchanges[3].role == "sub_assistant" + assert orchestrator.state.task_exchanges[4].role == "refiner_assistant" - assert mock_get_full_response.call_count == 4 - assert mock_create_assistant.call_count == 4 - assert os.path.exists(os.path.join(orchestrator.output_dir, "exchange_log.md")) +@pytest.mark.asyncio +@patch("src.orchestrator.create_assistant") +async def test_run_workflow_error(mock_create_assistant, orchestrator): + mock_main_assistant = AsyncMock() + mock_create_assistant.return_value = mock_main_assistant + orchestrator.workers.plan_tasks = AsyncMock(side_effect=Exception("API Error")) -@patch("builtins.open", new_callable=MagicMock) -@patch("os.path.join", return_value="mocked_path") -def test_save_exchange_log(mock_join, mock_open, orchestrator): - mock_file = MagicMock() - mock_open.return_value.__enter__.return_value = mock_file + with pytest.raises(WorkflowError): + await orchestrator.run_workflow("Test objective") + +def test_save_exchange_log(orchestrator): orchestrator.state.task_exchanges = [ TaskExchange(role="user", content="Test objective"), TaskExchange(role="main_assistant", content="Test response"), @@ -83,26 +109,21 @@ def test_save_exchange_log(mock_join, mock_open, orchestrator): orchestrator._save_exchange_log("Test objective", "Test output") - mock_open.assert_called_once_with("mocked_path", "w") - mock_file.write.assert_called() - - -@patch("src.orchestrator.get_full_response") -def test_run_workflow_error(mock_get_full_response, orchestrator): - mock_get_full_response.side_effect = Exception("API Error") - - with pytest.raises(WorkflowError): - orchestrator.run_workflow("Test objective") - + log_file_path = os.path.join(orchestrator.output_dir, "exchange_log.md") + assert 
os.path.exists(log_file_path) -def test_task_exchange_validation(): - with pytest.raises(ValueError): - TaskExchange(role="invalid_role", content="Test content") + with open(log_file_path, "r") as f: + content = f.read() + assert "Test objective" in content + assert "Test response" in content + assert "Test output" in content def test_state_to_dict(orchestrator): orchestrator.state.task_exchanges.append(TaskExchange(role="user", content="Test")) - orchestrator.state.tasks.append(Task(task="Test task", result="Test result")) + orchestrator.state.tasks.append( + Task(task="Test task", prompt="Test prompt", result="Test result") + ) state_dict = orchestrator.state.to_dict() assert len(state_dict["task_exchanges"]) == 1 diff --git a/tests/test_workers.py b/tests/test_workers.py new file mode 100644 index 0000000..3baface --- /dev/null +++ b/tests/test_workers.py @@ -0,0 +1,160 @@ +import json +from unittest.mock import AsyncMock, call, patch + +import pytest + +from src.utils.exceptions import WorkerError +from src.workers import PlanResponse, SAAsWorkers, WorkerTask + + +@pytest.fixture +def mock_assistant(): + return AsyncMock() + + +@pytest.fixture +def workers(): + return SAAsWorkers(num_workers=3) + + +@pytest.mark.asyncio +async def test_plan_tasks_single_task(workers, mock_assistant): + expected_response = PlanResponse( + objective_completion=True, + explanation="This is a simple task that doesn't require subtasks.", + tasks=None, + ) + json_response = json.dumps(expected_response.model_dump()) + + def check_plan(assistant, prompt): + assert "Simple objective" in prompt + return json_response + + with patch("src.workers.Assistant") as MockAssistant: + MockAssistant.return_value = AsyncMock() + with patch("src.workers.get_full_response", side_effect=check_plan): + result = await workers.plan_tasks("Simple objective", mock_assistant) + + assert isinstance(result, PlanResponse) + assert result.objective_completion == True + assert result.explanation == "This is a simple task that doesn't require subtasks." + assert result.tasks is None + MockAssistant.assert_called_once_with( + name="TaskPlanner", + llm=mock_assistant.llm, + description="You are a task planner that analyzes objectives and breaks them down into subtasks if necessary.", + ) + + +@pytest.mark.asyncio +async def test_plan_tasks_multiple_tasks(workers, mock_assistant): + expected_response = PlanResponse( + objective_completion=False, + explanation="This objective requires multiple subtasks.", + tasks=[ + WorkerTask(task="Subtask 1", prompt="Do subtask 1"), + WorkerTask(task="Subtask 2", prompt="Do subtask 2"), + WorkerTask(task="Subtask 3", prompt="Do subtask 3"), + ], + ) + json_response = json.dumps(expected_response.model_dump()) + + def check_plan(assistant, prompt): + assert "Complex objective" in prompt + return json_response + + with patch("src.workers.Assistant") as MockAssistant: + MockAssistant.return_value = AsyncMock() + with patch("src.workers.get_full_response", side_effect=check_plan): + result = await workers.plan_tasks("Complex objective", mock_assistant) + + assert isinstance(result, PlanResponse) + assert result.objective_completion == False + assert result.explanation == "This objective requires multiple subtasks." 
+ assert len(result.tasks) == 3 + assert all(isinstance(task, WorkerTask) for task in result.tasks) + MockAssistant.assert_called_once_with( + name="TaskPlanner", + llm=mock_assistant.llm, + description="You are a task planner that analyzes objectives and breaks them down into subtasks if necessary.", + ) + + +@pytest.mark.asyncio +async def test_process_tasks(workers): + tasks = [ + WorkerTask(task="Task 1", prompt="Do task 1"), + WorkerTask(task="Task 2", prompt="Do task 2"), + WorkerTask(task="Task 3", prompt="Do task 3"), + ] + workers.execute_task = AsyncMock(side_effect=["Result 1", "Result 2", "Result 3"]) + results = await workers.process_tasks(tasks) + assert len(results) == 3 + assert all(task.result for task in results) + assert [task.result for task in results] == ["Result 1", "Result 2", "Result 3"] + + +@pytest.mark.asyncio +async def test_execute_task(workers): + worker = AsyncMock() + task = WorkerTask(task="Test task", prompt="Test prompt") + + def check_execute(assistant, prompt): + assert assistant == worker + assert prompt == "Test prompt" + return "Task result" + + with patch("src.workers.get_full_response", side_effect=check_execute): + result = await workers.execute_task(worker, task) + assert result == "Task result" + + +@pytest.mark.asyncio +async def test_summarize_results(workers, mock_assistant): + objective = "Test objective" + tasks = [ + WorkerTask(task="Task 1", prompt="Prompt 1", result="Result 1"), + WorkerTask(task="Task 2", prompt="Prompt 2", result="Result 2"), + ] + expected_summary = "Summary of results" + + def check_prompt(assistant, prompt): + assert assistant == mock_assistant + assert objective in prompt + assert all(task.task in prompt for task in tasks) + assert all(task.result in prompt for task in tasks) + return expected_summary + + with patch("src.workers.get_full_response", side_effect=check_prompt) as mock_get_full_response: + result = await workers.summarize_results(objective, tasks, mock_assistant) + + assert result == expected_summary + mock_get_full_response.assert_called_once() + + +@pytest.mark.asyncio +async def test_plan_tasks_error_handling(workers, mock_assistant): + with patch("src.workers.Assistant"): + with patch("src.workers.get_full_response", return_value="Invalid JSON"): + with pytest.raises(WorkerError): + await workers.plan_tasks("Test objective", mock_assistant) + + +@pytest.mark.asyncio +async def test_process_tasks_error_handling(workers): + tasks = [WorkerTask(task="Task 1", prompt="Prompt 1")] + workers.execute_task = AsyncMock(side_effect=Exception("Task execution error")) + + results = await workers.process_tasks(tasks) + assert len(results) == 1 + assert results[0].result.startswith("Error:") + + +@pytest.mark.asyncio +async def test_execute_task_error_handling(workers): + worker = AsyncMock() + task = WorkerTask(task="Test task", prompt="Test prompt") + + with patch("src.workers.get_full_response", side_effect=Exception("Execution error")): + with pytest.raises(WorkerError): + await workers.execute_task(worker, task)
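For reference, `Orchestrator.run_workflow` is now a coroutine, so synchronous callers bridge into it with `asyncio.run`, exactly as the `src/main.py` hunk above does. A minimal usage sketch (the objective string is illustrative, and it assumes the API keys that `src/config.py` loads from `.env` are available):

```python
import asyncio

from src.orchestrator import Orchestrator


def main() -> None:
    # Orchestrator() creates the ./output directory and a default SAAsWorkers pool.
    orchestrator = Orchestrator()
    # run_workflow is async; sync code drives it to completion with asyncio.run,
    # mirroring the change in src/main.py.
    final_output = asyncio.run(orchestrator.run_workflow("Summarize the project README"))
    print(final_output)


if __name__ == "__main__":
    main()
```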
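The planner contract in `src/workers.py` deserves a concrete illustration: `plan_tasks` expects the model to return bare JSON that validates against `PlanResponse`, and because pydantic's `ValidationError` subclasses `ValueError`, the `except ValueError` branch converts schema violations into `WorkerError`. Below is a self-contained sketch of that round trip, assuming pydantic v2 (which the `ConfigDict` usage in `src/orchestrator.py` implies); the two model classes are copied from the diff:

```python
import json
from typing import List, Optional

from pydantic import BaseModel, Field, ValidationError


class WorkerTask(BaseModel):
    task: str = Field(..., description="Brief description of the task")
    prompt: str = Field(..., description="Detailed prompt for the worker to accomplish the task")
    result: Optional[str] = Field(None, description="Result of the task execution")


class PlanResponse(BaseModel):
    objective_completion: bool
    explanation: str
    tasks: Optional[List[WorkerTask]] = None


# A well-formed planner reply parses cleanly -- the same path plan_tasks takes.
raw = json.dumps(
    {
        "objective_completion": False,
        "explanation": "Needs decomposition.",
        "tasks": [{"task": "Research", "prompt": "Collect three recent sources."}],
    }
)
plan = PlanResponse(**json.loads(raw))
print(plan.tasks[0].task)  # -> Research

# A reply missing required fields raises ValidationError (a ValueError subclass),
# which plan_tasks re-raises as WorkerError.
try:
    PlanResponse(**{"explanation": "missing objective_completion"})
except ValidationError as e:
    print(f"rejected with {e.error_count()} validation error(s)")
```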
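`SAAsWorkers.process_tasks` fans the blocking `get_full_response` calls out to worker threads with `asyncio.to_thread`, then collects them with `asyncio.gather(..., return_exceptions=True)` so a single failed subtask is recorded as an error string instead of cancelling its siblings. The same pattern in isolation (`blocking_worker` is a hypothetical stand-in for the real LLM call):

```python
import asyncio
import time


def blocking_worker(prompt: str) -> str:
    # Stand-in for get_full_response(worker, prompt): a slow, blocking call.
    time.sleep(0.1)
    if "fail" in prompt:
        raise RuntimeError("simulated API error")
    return f"done: {prompt}"


async def run_all(prompts: list) -> list:
    # to_thread keeps the event loop free while each blocking call runs.
    coros = [asyncio.to_thread(blocking_worker, p) for p in prompts]
    results = await asyncio.gather(*coros, return_exceptions=True)
    # Mirror process_tasks: an exception becomes an "Error: ..." result,
    # while the remaining tasks still return normally.
    return [f"Error: {r}" if isinstance(r, Exception) else r for r in results]


if __name__ == "__main__":
    print(asyncio.run(run_all(["task 1", "please fail", "task 3"])))
    # -> ['done: task 1', 'Error: simulated API error', 'done: task 3']
```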
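Finally, the reworked tests rely on `unittest.mock.AsyncMock` together with the `@pytest.mark.asyncio` marker that the newly added `pytest-asyncio` dependency provides. A tiny standalone example of the pattern (hypothetical test, assuming pytest-asyncio's strict mode, in which the explicit marker is required):

```python
from unittest.mock import AsyncMock

import pytest


@pytest.mark.asyncio
async def test_planner_is_awaited() -> None:
    # AsyncMock returns an awaitable, so it can replace coroutine methods
    # such as SAAsWorkers.plan_tasks in the orchestrator tests.
    plan_tasks = AsyncMock(return_value="plan")
    result = await plan_tasks("objective", "assistant")
    assert result == "plan"
    plan_tasks.assert_awaited_once_with("objective", "assistant")
```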