feat(workers): implement SAAsWorkers for parallel task execution
- Add SAAsWorkers class for asynchronous execution of tasks
- Integrate SAAsWorkers with Orchestrator for workflow management
- Enhance Orchestrator to support single-task and multi-task scenarios
- Update tests for new functionality and error handling
- Modify requirements and configuration settings
jeblister committed Jul 5, 2024
1 parent 54590f1 commit 6dd31f4
Showing 12 changed files with 497 additions and 137 deletions.
55 changes: 55 additions & 0 deletions ROADMAP.md
@@ -0,0 +1,55 @@
# Updated 7-Day SAA Implementation Roadmap

## Day 1: SAAs Workers Integration

- [x] Implement SAAsWorkers class in a new file `src/workers.py`
- [x] Add asynchronous execution capabilities for parallel processing
- [x] Integrate SAAsWorkers with the existing Orchestrator class

## Day 2: Enhance Main Assistant and Prompts

- [ ] Improve the `_generate_main_prompt` method in `src/orchestrator.py` for better task breakdown
- [ ] Implement logic for the MAIN_ASSISTANT to handle tasks without subtask decomposition
- [ ] Enhance prompt writing capabilities for SUB_ASSISTANT tasks

## Day 3: Implement Parallel Research Use Case

- [ ] Develop a parallel research system in `src/use_cases/research.py`
- [ ] Integrate the research use case with SAAsWorkers
- [ ] Add necessary tools for web scraping and data processing
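The fan-out pattern behind the parallel research item can be sketched as follows (the `research_query` body is a hypothetical stand-in for a real search or scraping tool call):

```python
import asyncio


async def research_query(query: str) -> str:
    # Hypothetical fetch; a real implementation would call a web search/scraping tool
    await asyncio.sleep(0)
    return f"findings for {query!r}"


async def parallel_research(queries: list[str]) -> list[str]:
    # Fan out one coroutine per query and collect findings in order
    return await asyncio.gather(*(research_query(q) for q in queries))


findings = asyncio.run(parallel_research(["topic a", "topic b"]))
print(findings)
```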

## Day 4: CLI Enhancements and Basic UI

- [ ] Extend the CLI in `src/main.py` to support new SAAsWorkers functionality
- [ ] Implement a basic Streamlit UI for interaction (instead of SvelteKit, due to time constraints)
- [ ] Create commands for the research use case

## Day 5: Testing and Error Handling

- [ ] Update existing tests in `tests/test_orchestrator.py` for new functionality
- [ ] Add new tests for SAAsWorkers and the research use case
- [ ] Enhance error handling and logging throughout the project

## Day 6: Documentation and Examples

- [ ] Update the README.md with new features and usage instructions
- [ ] Create example scripts for the research use case
- [ ] Document the SAAsWorkers implementation and integration

## Day 7: Optimization and Final Testing

- [ ] Optimize parallel execution in SAAsWorkers
- [ ] Conduct end-to-end testing of the entire system
- [ ] Address any remaining bugs or issues
- [ ] Prepare for deployment (if applicable)

# Backlog (Future Development)

1. Implement additional use cases (e.g., content creation, autocomplete)
2. Develop a plugin system for easy integration of new tools
3. Enhance the configuration management system
4. Implement a full-featured FastAPI-based API
5. Develop strategies for handling larger workloads and scaling
6. Integrate additional Phidata tools and features
7. Implement long-term memory and knowledge base systems
8. Create a more advanced UI with data visualization capabilities
7 changes: 4 additions & 3 deletions requirements-dev.txt
@@ -4,12 +4,13 @@
# Testing
pytest==8.2.2
pytest-mock==3.14.0
+pytest-asyncio

# Type checking
-mypy==1.7.1
+mypy==1.7.1

# Debugging
-ipdb==0.13.13
+ipdb==0.13.13

# Security
bandit==1.7.5
@@ -27,4 +28,4 @@ python-dotenv==1.0.1 # Already in requirements.txt, but included here for completeness
typer[all]==0.12.3 # Already in requirements.txt, but included here for completeness
# Code Formating
black==24.4.2
-isort==5.12.0
+isort==5.12.0
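The new `pytest-asyncio` dev dependency is what lets the test suite await coroutines directly. A minimal sketch of such a test (the `run_workflow` coroutine here is a stand-in, not the project's real orchestrator):

```python
import asyncio

import pytest


async def run_workflow(objective: str) -> str:
    # Stand-in for an async workflow entry point
    await asyncio.sleep(0)
    return f"completed: {objective}"


@pytest.mark.asyncio
async def test_run_workflow():
    # pytest-asyncio runs this coroutine on an event loop it manages
    result = await run_workflow("demo")
    assert result == "completed: demo"
```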
12 changes: 6 additions & 6 deletions requirements.txt
@@ -25,7 +25,7 @@ numpy==2.0.0
python-multipart==0.0.9

# Async support
-anyio==4.4.0
+asyncio

# CLI enhancements
click==8.1.7
@@ -34,8 +34,8 @@ shellingham==1.5.4
# Time handling
python-dateutil==2.9.0.post0

-# pgvector
-# pypdf
-# psycopg2-binary
-# sqlalchemy
-# fastapi
+# pgvector
+# pypdf
+# psycopg2-binary
+# sqlalchemy
+# fastapi
10 changes: 6 additions & 4 deletions src/__init__.py
@@ -1,14 +1,16 @@
-from .assistants import get_full_response, main_assistant, refiner_assistant, sub_assistant
+from .assistants import create_assistant, get_full_response
from .config import settings
from .orchestrator import Orchestrator, Task, TaskExchange
+from .workers import PlanResponse, SAAsWorkers

__all__ = [
"get_full_response",
-"main_assistant",
-"refiner_assistant",
-"sub_assistant",
+"create_assistant",
"settings",
"Orchestrator",
"Task",
"TaskExchange",
+"SAAsWorkers",
+"WorkerTask",
+"PlanResponse",
]
1 change: 1 addition & 0 deletions src/assistants.py
@@ -74,6 +74,7 @@ def create_assistant(name: str, model: str):
read_file,
list_files,
],
+debug_mode=True,
)
except Exception as e:
logger.error(f"Error creating assistant {name} with model {model}: {str(e)}")
9 changes: 6 additions & 3 deletions src/config.py
@@ -38,17 +38,20 @@ def tavily_api_key(self) -> str:
OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")

# Assistant settings
-MAIN_ASSISTANT: str = "claude-3-5-sonnet-20240620"
-SUB_ASSISTANT: str = "gpt-3.5-turbo"
+MAIN_ASSISTANT: str = "claude-3-sonnet-20240229"
+SUB_ASSISTANT: str = "claude-3-haiku-20240307"
REFINER_ASSISTANT: str = "gemini-1.5-pro-preview-0409"

# Fallback models
-FALLBACK_MODEL_1: str = "claude-3-sonnet-20240229"
+FALLBACK_MODEL_1: str = "gpt-3.5-turbo"
FALLBACK_MODEL_2: str = "gpt-3.5-turbo"

# Tools
TAVILY_API_KEY: Optional[str] = os.getenv("TAVILY_API_KEY")

+# New setting for SAAsWorkers
+NUM_WORKERS: int = 3

class Config:
env_file = ".env"
extra = "ignore" # This will ignore any extra fields in the environment
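The new `NUM_WORKERS` setting can be exercised with a small stand-in (plain `os.getenv` here rather than the project's pydantic settings class, and the environment-variable name is an assumption):

```python
import os


class Settings:
    # Mirrors the new SAAsWorkers setting; defaults to 3 as in src/config.py
    NUM_WORKERS: int = int(os.getenv("NUM_WORKERS", "3"))
    MAIN_ASSISTANT: str = os.getenv("MAIN_ASSISTANT", "claude-3-sonnet-20240229")


settings = Settings()
print(settings.NUM_WORKERS)
```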
4 changes: 3 additions & 1 deletion src/main.py
@@ -1,3 +1,5 @@
+import asyncio

import typer
from rich import print as rprint

@@ -19,7 +21,7 @@ def run_workflow(
try:
rprint("[bold]Starting SAA Orchestrator[/bold]")
orchestrator = Orchestrator()
-result = orchestrator.run_workflow(full_objective)
+result = asyncio.run(orchestrator.run_workflow(full_objective))

rprint("\n[bold green]Workflow completed![/bold green]")
rprint("\n[bold]Final Output:[/bold]")
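Since `run_workflow` is now a coroutine, the synchronous CLI command has to drive it to completion with `asyncio.run`. The bridge looks roughly like this (the coroutine body is a stand-in for the real `Orchestrator.run_workflow`):

```python
import asyncio


async def run_workflow(objective: str) -> str:
    # Stand-in for the now-async Orchestrator.run_workflow
    await asyncio.sleep(0)
    return f"result for {objective}"


def run(objective: str) -> str:
    # Synchronous entry point (e.g. the body of a Typer command):
    # asyncio.run creates an event loop, runs the coroutine, and tears the loop down
    return asyncio.run(run_workflow(objective))


print(run("demo"))  # → result for demo
```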
121 changes: 52 additions & 69 deletions src/orchestrator.py
@@ -1,13 +1,13 @@
import json
import os
from typing import Any, Dict, List, Literal

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

-from .assistants import create_assistant, get_full_response
+from .assistants import create_assistant
from .config import settings
from .utils.exceptions import AssistantError, WorkflowError
from .utils.logging import setup_logging
+from .workers import PlanResponse, SAAsWorkers

logger = setup_logging()

@@ -19,10 +19,11 @@ class TaskExchange(BaseModel):

class Task(BaseModel):
task: str
+prompt: str
result: str

def to_dict(self) -> Dict[str, Any]:
-return {"task": str(self.task), "result": str(self.result)}
+return {"task": str(self.task), "prompt": str(self.prompt), "result": str(self.result)}


class State(BaseModel):
@@ -39,95 +40,77 @@ def to_dict(self) -> Dict[str, Any]:
class Orchestrator(BaseModel):
state: State = State()
output_dir: str = Field(default_factory=lambda: os.path.join(os.getcwd(), "output"))
+workers: SAAsWorkers = Field(default_factory=lambda: SAAsWorkers(settings.NUM_WORKERS))

+model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, **data):
super().__init__(**data)
os.makedirs(self.output_dir, exist_ok=True)

-def run_workflow(self, objective: str) -> str:
+async def run_workflow(self, objective: str) -> str:
"""
-Executes the workflow to accomplish the given objective.
+Executes the workflow to accomplish the given objective using SAAsWorkers.
Args:
objective (str): The main task or goal to be accomplished.
Returns:
str: The final refined output of the workflow.
"""

logger.info(f"Starting workflow with objective: {objective}")
self.state.task_exchanges.append(TaskExchange(role="user", content=objective))
-task_counter = 1
-try:
-while True:
-logger.info(f"Starting task {task_counter}")
-main_prompt = self._generate_main_prompt(objective)
-main_response = self._get_assistant_response("main", main_prompt)

-if main_response.startswith("ALL DONE:"):
-logger.info("Workflow completed")
-break

-sub_task_prompt = self._generate_sub_task_prompt(main_response)
-sub_response = self._get_assistant_response("sub", sub_task_prompt)

-self.state.tasks.append(Task(task=main_response, result=sub_response))
-task_counter += 1

-refined_output = self._get_refined_output(objective)
-self._save_exchange_log(objective, refined_output)
-logger.info("Exchange log saved")
+try:
+main_assistant = create_assistant("MainAssistant", settings.MAIN_ASSISTANT)
+refiner_assistant = create_assistant("RefinerAssistant", settings.REFINER_ASSISTANT)

+# Plan tasks
+plan_result: PlanResponse = await self.workers.plan_tasks(objective, main_assistant)

+if plan_result.objective_completion:
+# Single-task scenario
+final_output = plan_result.explanation
+self.state.task_exchanges.append(
+TaskExchange(role="main_assistant", content=final_output)
+)
+else:
+# Multi-task scenario
+tasks = plan_result.tasks if plan_result.tasks else []
+self.state.task_exchanges.append(
+TaskExchange(role="main_assistant", content="\n".join([t.task for t in tasks]))
+)

+# Process tasks
+results = await self.workers.process_tasks(tasks)
+for result in results:
+self.state.tasks.append(
+Task(task=result.task, prompt=result.prompt, result=result.result)
+)
+self.state.task_exchanges.append(
+TaskExchange(role="sub_assistant", content=result.result)
+)

+# Summarize results
+final_output = await self.workers.summarize_results(
+objective, results, refiner_assistant
+)
+self.state.task_exchanges.append(
+TaskExchange(role="refiner_assistant", content=final_output)
+)

+self._save_exchange_log(objective, final_output)
+logger.info("Workflow completed and exchange log saved")

+return final_output

-return refined_output
except AssistantError as e:
logger.error(f"Assistant error: {str(e)}")
raise WorkflowError(f"Workflow failed due to assistant error: {str(e)}")
except Exception as e:
logger.exception("Unexpected error in workflow execution")
raise WorkflowError(f"Unexpected error in workflow execution: {str(e)}")

-def _generate_main_prompt(self, objective: str) -> str:
-return (
-f"Objective: {objective}\n\n"
-f"Current progress:\n{json.dumps(self.state.to_dict(), indent=2)}\n\n"
-"Break down this objective into the next specific sub-task, or if the objective is fully achieved, "
-"start your response with 'ALL DONE:' followed by the final output."
-)

-def _generate_sub_task_prompt(self, main_response: str) -> str:
-return (
-f"Previous tasks: {json.dumps([task.to_dict() for task in self.state.tasks], indent=2)}\n\n"
-f"Current task: {main_response}\n\n"
-"Execute this task and provide the result. Use the provided functions to create, read, or list files as needed. "
-f"All file operations should be relative to the '{self.output_dir}' directory."
-)

-def _get_assistant_response(self, assistant_type: str, prompt: str) -> str:
-try:
-assistant_model = getattr(settings, f"{assistant_type.upper()}_ASSISTANT")
-assistant = create_assistant(f"{assistant_type.capitalize()}Assistant", assistant_model)
-response = get_full_response(assistant, prompt)
-logger.info(f"{assistant_type.capitalize()} assistant response received")
-self.state.task_exchanges.append(
-TaskExchange(role=f"{assistant_type}_assistant", content=response)
-)
-return response
-except Exception as e:
-logger.error(f"Error getting response from {assistant_type} assistant: {str(e)}")
-raise AssistantError(
-f"Error getting response from {assistant_type} assistant: {str(e)}"
-)

-def _get_refined_output(self, objective: str) -> str:
-refiner_prompt = (
-f"Original objective: {objective}\n\n"
-f"Task breakdown and results: {json.dumps([task.to_dict() for task in self.state.tasks], indent=2)}\n\n"
-"Please refine these results into a coherent final output, summarizing the project structure created. "
-f"You can use the provided functions to list and read files if needed. All files are in the '{self.output_dir}' directory. "
-"Provide your response as a string, not a list or dictionary."
-)
-return self._get_assistant_response("refiner", refiner_prompt)

def _save_exchange_log(self, objective: str, final_output: str):
"""
Saves the workflow exchange log to a markdown file.
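The new control flow in `run_workflow` — plan, branch on single- vs multi-task, execute in parallel, summarize — can be condensed into a short sketch (the `PlanResponse` fields mirror the diff; the `execute` coroutine and the `" | "` join are stand-ins for the real sub-assistant calls and refiner summary):

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class PlanResponse:
    objective_completion: bool
    explanation: str = ""
    tasks: list[str] = field(default_factory=list)


async def execute(task: str) -> str:
    # Stand-in for a sub-assistant working one task
    await asyncio.sleep(0)
    return f"{task}: ok"


async def run_workflow(plan: PlanResponse) -> str:
    if plan.objective_completion:
        # Single-task scenario: the planner already answered the objective
        return plan.explanation
    # Multi-task scenario: run sub-tasks concurrently, then combine the results
    results = await asyncio.gather(*(execute(t) for t in plan.tasks))
    return " | ".join(results)


print(asyncio.run(run_workflow(PlanResponse(False, tasks=["a", "b"]))))  # → a: ok | b: ok
```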
4 changes: 4 additions & 0 deletions src/utils/exceptions.py
@@ -16,3 +16,7 @@ class ConfigurationError(SAAOrchestratorError):

class PluginError(SAAOrchestratorError):
"""Raised when there's an error with a plugin"""


+class WorkerError(Exception):
+"""Base exception class for SAA Workers"""