Skip to content

Commit

Permalink
Merge pull request #1 from PropsAI/decorator
Browse files Browse the repository at this point in the history
Decorator
  • Loading branch information
k11kirky authored Oct 29, 2024
2 parents 384219a + 4eec483 commit 8b5e743
Show file tree
Hide file tree
Showing 27 changed files with 404 additions and 586 deletions.
371 changes: 123 additions & 248 deletions README.md

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions agentserve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
# agentserve/__init__.py
from .agent import Agent, AgentInput
from .agent_server import AgentServer
from .cli import main as cli_main
from .agent_server import AgentServer as app
Binary file added agentserve/__pycache__/__init__.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added agentserve/__pycache__/config.cpython-312.pyc
Binary file not shown.
14 changes: 14 additions & 0 deletions agentserve/agent_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# agentserve/agent_registry.py

class AgentRegistry:
def __init__(self):
self._agent_function = None

def register_agent(self, func):
self._agent_function = func
return func

def get_agent(self):
if not self._agent_function:
raise Exception("No agent function registered")
return self._agent_function
78 changes: 41 additions & 37 deletions agentserve/agent_server.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,41 @@
# agentserve/agent_server.py

from fastapi import FastAPI, HTTPException
from typing import Dict, Any
from rq import Queue
from redis import Redis
from .queues.task_queue import TaskQueue
from .agent_registry import AgentRegistry
from typing import Dict, Any, Optional
from .config import Config
import uuid
import os

class AgentServer:
def __init__(self, agent_class: type):
self.agent = agent_class()
def __init__(self, config: Optional[Config] = None):
self.app = FastAPI()
self.redis_conn = Redis(host=os.getenv('REDIS_HOST', 'redis'), port=6379)
self.task_queue = Queue(connection=self.redis_conn)
self.agent_registry = AgentRegistry()
self.config = config or Config()
self.task_queue = self._initialize_task_queue()
self.agent = self.agent_registry.register_agent
self._setup_routes()


def _initialize_task_queue(self):
task_queue_type = self.config.get('task_queue', 'local').lower()
if task_queue_type == 'celery':
from .celery_task_queue import CeleryTaskQueue
return CeleryTaskQueue(self.config)
elif task_queue_type == 'redis':
from .redis_task_queue import RedisTaskQueue
return RedisTaskQueue(self.config)
else:
from .queues.local_task_queue import LocalTaskQueue
return LocalTaskQueue()

def _setup_routes(self):
@self.app.post("/task/sync")
async def sync_task(task_data: Dict[str, Any]):
try:
result = self.agent._process(task_data)
agent_function = self.agent_registry.get_agent()
result = agent_function(task_data)
return {"result": result}
except ValueError as ve:
# Check if this is a Pydantic validation error
if hasattr(ve, 'errors'):
raise HTTPException(
status_code=400,
Expand All @@ -38,38 +51,29 @@ async def sync_task(task_data: Dict[str, Any]):
@self.app.post("/task/async")
async def async_task(task_data: Dict[str, Any]):
task_id = str(uuid.uuid4())
job = self.task_queue.enqueue(self.agent._process, task_data, job_id=task_id)
agent_function = self.agent_registry.get_agent()
self.task_queue.enqueue(agent_function, task_data, task_id)
return {"task_id": task_id}

@self.app.get("/task/status/{task_id}")
async def get_status(task_id: str):
job = self.task_queue.fetch_job(task_id)
if job:
return {"status": job.get_status()}
else:
status = self.task_queue.get_status(task_id)
if status == 'not_found':
raise HTTPException(status_code=404, detail="Task not found")
return {"status": status}

@self.app.get("/task/result/{task_id}")
async def get_result(task_id: str):
job = self.task_queue.fetch_job(task_id)
if job:
if job.is_finished:
return {"result": job.result}
elif job.is_failed:
# Extract the error information
exc_info = job.exc_info
if isinstance(exc_info, ValueError):
# Check if this was a Pydantic validation error
if hasattr(exc_info, 'errors'):
return {
"status": "failed",
"error": {
"message": "Validation error",
"errors": exc_info.errors()
}
}
return {"status": "failed", "error": str(exc_info)}
try:
result = self.task_queue.get_result(task_id)
if result is not None:
return {"result": result}
else:
return {"status": job.get_status()}
else:
raise HTTPException(status_code=404, detail="Task not found")
status = self.task_queue.get_status(task_id)
return {"status": status}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

def run(self, host="0.0.0.0", port=8000):
import uvicorn
uvicorn.run(self.app, host=host, port=port)
151 changes: 27 additions & 124 deletions agentserve/cli.py
Original file line number Diff line number Diff line change
@@ -1,134 +1,37 @@
# agentserve/cli.py

import click
import os
import shutil
from pathlib import Path

TEMPLATES_DIR = Path(__file__).parent / 'templates'

# Mapping of framework choices to their respective agent import and class names
FRAMEWORKS = {
'openai': {
'agent_import': 'from agents.example_openai_agent import ExampleAgent',
'agent_class': 'ExampleAgent',
'agent_template_filename': 'example_openai_agent.py.tpl'
},
'langchain': {
'agent_import': 'from agents.example_langchain_agent import ExampleAgent',
'agent_class': 'ExampleAgent',
'agent_template_filename': 'example_langchain_agent.py.tpl'
},
'llama': {
'agent_import': 'from agents.example_llama_agent import ExampleAgent',
'agent_class': 'ExampleAgent',
'agent_template_filename': 'example_llamaindex_agent.py.tpl'
},
'blank': {
'agent_import': 'from agents.example_agent import ExampleAgent',
'agent_class': 'ExampleAgent',
'agent_template_filename': 'example_agent.py.tpl'
}
}
from .config import Config

@click.group()
def main():
"""CLI tool for managing AI agents."""
click.echo(click.style("\nWelcome to AgentServe CLI\n\n", fg='green', bold=True))
click.echo("Go to https://github.com/Props/agentserve for more information.\n\n\n")

@main.command()
@click.argument('project_name')
@click.option('--framework', type=click.Choice(FRAMEWORKS.keys()), default='openai', help='Type of agent framework to use.')
def init(project_name, framework):
"""Initialize a new agent project."""
project_path = Path.cwd() / project_name

# Check if the project directory already exists
if project_path.exists():
click.echo(f"Directory '{project_name}' already exists.")
return

# Define the list of target directories to be created
target_dirs = ['agents']

# Create project directory
project_path.mkdir()

# Create subdirectories
for dir_name in target_dirs:
(project_path / dir_name).mkdir()

# Copy and process main.py template
main_tpl_path = TEMPLATES_DIR / 'main.py.tpl'
with open(main_tpl_path, 'r') as tpl_file:
main_content = tpl_file.read()

agent_import = FRAMEWORKS[framework]['agent_import']
agent_class = FRAMEWORKS[framework]['agent_class']
main_content = main_content.replace('{{AGENT_IMPORT}}', agent_import)
main_content = main_content.replace('{{AGENT_CLASS}}', agent_class)

main_dst_path = project_path / 'main.py'
with open(main_dst_path, 'w') as dst_file:
dst_file.write(main_content)

# Copy agent template to agents directory
agent_template_filename = FRAMEWORKS[framework]['agent_template_filename']
agent_src_path = TEMPLATES_DIR / 'agents' / agent_template_filename
agent_dst_path = project_path / 'agents' / agent_template_filename[:-4] # Remove '.tpl' extension
shutil.copyfile(agent_src_path, agent_dst_path)

# Copy .env template
env_tpl_path = TEMPLATES_DIR / '.env.tpl'
env_dst_path = project_path / '.env'
shutil.copyfile(env_tpl_path, env_dst_path)

# Create requirements.txt
requirements_path = project_path / 'requirements.txt'
with open(requirements_path, 'w') as f:
f.write('agentserve\n')
if framework == 'openai':
f.write('openai\n')
elif framework == 'langchain':
f.write('langchain\n')
elif framework == 'llama':
f.write('llama-index\n')

click.echo(f"Initialized new agent project at '{project_path}' with '{framework}' framework.")
click.echo(f" - Now run 'cd {project_name}'")
click.echo(" - Update the .env file with your API keys and other environment variables")
click.echo(" - To generate Dockerfiles, run 'agentserve build'")
click.echo(" - Then run 'agentserve run' to start the API server and worker.")

@main.command()
def build():
"""Generate Dockerfiles."""
project_path = Path.cwd()
docker_dir = project_path / 'docker'
docker_dir.mkdir(exist_ok=True)

templates = {
'Dockerfile.tpl': 'Dockerfile',
'docker-compose.yml.tpl': 'docker-compose.yml'
}

for tpl_name, dst_name in templates.items():
src_path = TEMPLATES_DIR / tpl_name
dst_path = docker_dir / dst_name
shutil.copyfile(src_path, dst_path)

click.echo(f"Dockerfiles have been generated in '{docker_dir}'.")

@main.command()
def run():
"""Run the API server and worker."""
docker_dir = Path.cwd() / 'docker'
if not docker_dir.exists():
click.echo("Docker directory not found. Please run 'agentserve build' first.")
return
os.chdir(docker_dir)
os.system('docker-compose up --build')

if __name__ == '__main__':
main()
@cli.command()
def startworker():
"""Starts the AgentServe worker (if required)."""
config = Config()
task_queue_type = config.get('task_queue', 'local').lower()

if task_queue_type == 'celery':
from .queues.celery_task_queue import CeleryTaskQueue
task_queue = CeleryTaskQueue(config)
# Start the Celery worker
argv = [
'worker',
'--loglevel=info',
'--pool=solo', # Use 'solo' pool to avoid issues on some platforms
]
task_queue.celery_app.worker_main(argv)
elif task_queue_type == 'redis':
# For Redis (RQ), start a worker process
from rq import Worker, Connection
from .queues.redis_task_queue import RedisTaskQueue
task_queue = RedisTaskQueue(config)
with Connection(task_queue.redis_conn):
worker = Worker([task_queue.task_queue])
worker.work(log_level='INFO')
else:
click.echo("No worker required for the 'local' task queue.")
59 changes: 59 additions & 0 deletions agentserve/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# agentserve/config.py

import os
import yaml

class Config:
def __init__(self):
self.config = self._load_config()

def _load_config(self):
# Load from 'agentserve.yaml' if exists
config_path = 'agentserve.yaml'
config = {}
if os.path.exists(config_path):
with open(config_path, 'r') as file:
config = yaml.safe_load(file) or {}

# Override with environment variables
config['task_queue'] = os.getenv('AGENTSERVE_TASK_QUEUE', config.get('task_queue', 'local'))

# Celery configuration
celery_broker_url = os.getenv('AGENTSERVE_CELERY_BROKER_URL')
if celery_broker_url:
config.setdefault('celery', {})['broker_url'] = celery_broker_url

# Redis configuration
redis_host = os.getenv('AGENTSERVE_REDIS_HOST')
redis_port = os.getenv('AGENTSERVE_REDIS_PORT')
if redis_host or redis_port:
redis_config = config.setdefault('redis', {})
if redis_host:
redis_config['host'] = redis_host
if redis_port:
redis_config['port'] = int(redis_port)

# Server configuration
server_host = os.getenv('AGENTSERVE_SERVER_HOST')
server_port = os.getenv('AGENTSERVE_SERVER_PORT')
if server_host or server_port:
server_config = config.setdefault('server', {})
if server_host:
server_config['host'] = server_host
if server_port:
server_config['port'] = int(server_port)

return config

def get(self, key, default=None):
return self.config.get(key, default)

def get_nested(self, *keys, default=None):
value = self.config
for key in keys:
if not isinstance(value, dict):
return default
value = value.get(key)
if value is None:
return default
return value
Binary file not shown.
Binary file not shown.
39 changes: 39 additions & 0 deletions agentserve/queues/calery_task_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# agentserve/celery_task_queue.py

from typing import Any, Dict
from .task_queue import TaskQueue
from ..config import Config
class CeleryTaskQueue(TaskQueue):
def __init__(self, config: Config):
try:
from celery import Celery
except ImportError:
raise ImportError("CeleryTaskQueue requires the 'celery' package. Please install it.")

broker_url = config.get('celery', {}).get('broker_url', 'pyamqp://guest@localhost//')
self.celery_app = Celery('agent_server', broker=broker_url)
self._register_tasks()

def _register_tasks(self):
@self.celery_app.task(name='agent_task')
def agent_task(task_data):
from .agent_registry import AgentRegistry
agent_registry = AgentRegistry()
agent_function = agent_registry.get_agent()
return agent_function(task_data)

def enqueue(self, agent_function, task_data: Dict[str, Any], task_id: str):
# Since the agent task is registered with Celery, we just send the task name
self.celery_app.send_task('agent_task', args=[task_data], task_id=task_id)

def get_status(self, task_id: str) -> str:
result = self.celery_app.AsyncResult(task_id)
return result.status

def get_result(self, task_id: str) -> Any:
result = self.celery_app.AsyncResult(task_id)
if result.state == 'SUCCESS':
return result.result
if result.state == 'FAILURE':
raise Exception(str(result.result))
return None
Loading

0 comments on commit 8b5e743

Please sign in to comment.