Skip to content

Commit

Permalink
feat: add Agent input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
k11kirky committed Oct 25, 2024
1 parent a4c3cca commit b2b31c4
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 9 deletions.
81 changes: 81 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ The goal of AgentServe is to provide the easiest way to take an local agent to p
- **Framework Agnostic:** AgentServe supports multiple agent frameworks (OpenAI, LangChain, LlamaIndex, and Blank).
- **Dockerized:** The output is a single docker image that you can deploy anywhere.
- **Easy to Use:** AgentServe provides a CLI tool to initialize and setup your AI agent projects.
- **Schema Validation:** Define input schemas for your agents using AgentInput to ensure data consistency and validation.

## Requirements

Expand Down Expand Up @@ -161,6 +162,86 @@ Get the result of a task.

- `result`: The result of the task.

## Defining Input Schemas

AgentServe uses AgentInput (an alias for Pydantic's BaseModel) to define and validate the input schemas for your agents. This ensures that the data received by your agents adheres to the expected structure, enhancing reliability and developer experience.
### Subclassing AgentInput
To define a custom input schema for your agent, subclass AgentInput and specify the required fields.

**Example:**

```python
# agents/custom_agent.py
from agentserve.agent import Agent, AgentInput
from typing import Optional, Dict, Any

class CustomTaskSchema(AgentInput):
input_text: str
parameters: Optional[Dict[str, Any]] = None

class CustomAgent(Agent):
input_schema = CustomTaskSchema

def process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
# Implement your processing logic here
input_text = task_data["input_text"]
parameters = task_data.get("parameters", {})
# Example processing
processed_text = input_text.upper() # Simple example
return {"processed_text": processed_text, "parameters": parameters}
```

### Updating Your Agent

When creating your agent, assign your custom schema to the input_schema attribute. This ensures that all incoming task_data is validated against your defined schema before processing.

**Steps:**

1. Define the Input Schema:

```python
from agentserve.agent import Agent, AgentInput
from typing import Optional, Dict, Any

class MyTaskSchema(AgentInput):
prompt: str
settings: Optional[Dict[str, Any]] = None
```

2. Implement the Agent:

```python
class MyAgent(Agent):
input_schema = MyTaskSchema

def process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
prompt = task_data["prompt"]
settings = task_data.get("settings", {})
# Your processing logic here
response = {"response": f"Echo: {prompt}", "settings": settings}
return response
```

### Handling Validation Errors

AgentServe will automatically validate incoming task_data against the defined input_schema. If the data does not conform to the schema, a 400 Bad Request error will be returned with details about the validation failure.

**Example Response:**

```json
{
"detail": [
{
"loc": ["body", "prompt"],
"msg": "field required",
"type": "value_error.missing"
}
]
}
```

Ensure that your clients provide data that matches the schema to avoid validation errors.

## CLI Usage

### Init Command (for new projects)
Expand Down
12 changes: 11 additions & 1 deletion agentserve/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
# agentserve/agent.py
from typing import Dict, Any
from typing import Dict, Any, Type
from pydantic import BaseModel

AgentInput = BaseModel # Alias BaseModel to AgentInput

class Agent:
input_schema: Type[AgentInput] = AgentInput

def _process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
# Validate task_data against input_schema
validated_data = self.input_schema(**task_data).dict()
return self._process(validated_data)

def process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
"""
User-defined method to process the incoming task data.
Expand Down
9 changes: 5 additions & 4 deletions agentserve/agent_server.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# agentserve/agent_server.py

from fastapi import FastAPI, HTTPException
from typing import Dict, Any, AsyncGenerator
from typing import Dict, Any
from rq import Queue
from redis import Redis
from fastapi.responses import StreamingResponse
import uuid
import os

Expand All @@ -20,15 +19,17 @@ def _setup_routes(self):
@self.app.post("/task/sync")
async def sync_task(task_data: Dict[str, Any]):
try:
result = self.agent.process(task_data)
result = self.agent._process(task_data)
return {"result": result}
except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@self.app.post("/task/async")
async def async_task(task_data: Dict[str, Any]):
task_id = str(uuid.uuid4())
job = self.task_queue.enqueue(self.agent.process, task_data, job_id=task_id)
job = self.task_queue.enqueue(self.agent._process, task_data, job_id=task_id)
return {"task_id": task_id}

@self.app.get("/task/status/{task_id}")
Expand Down
6 changes: 5 additions & 1 deletion agentserve/templates/agents/example_agent.py.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from agentserve import Agent
from agentserve import Agent, AgentInput

class ExampleInput(AgentInput):
prompt: str

class ExampleAgent(Agent):
input_schema = ExampleInput
def process(self, task_data):
return ""

7 changes: 6 additions & 1 deletion agentserve/templates/agents/example_langchain_agent.py.tpl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from agentserve import Agent
from agentserve import Agent, AgentInput
from langchain import OpenAI

class ExampleInput(AgentInput):
prompt: str

class ExampleAgent(Agent):
input_schema = ExampleInput

def __init__(self):
self.client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

Expand Down
7 changes: 6 additions & 1 deletion agentserve/templates/agents/example_llamaindex_agent.py.tpl
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from agentserve import Agent
from agentserve import Agent, AgentInput
from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader
import os

class ExampleInput(AgentInput):
query: str

class ExampleAgent(Agent):
input_schema = ExampleInput

def process(self, task_data):
# Load documents from a directory
documents = SimpleDirectoryReader('data').load_data()
Expand Down
7 changes: 6 additions & 1 deletion agentserve/templates/agents/example_openai_agent.py.tpl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from agentserve import Agent
from agentserve import Agent, AgentInput
from openai import OpenAI

class ExampleInput(AgentInput):
prompt: str

class ExampleAgent(Agent):
input_schema = ExampleInput

def process(self, task_data):
client = OpenAI()
response = client.chat.completions.create(
Expand Down

0 comments on commit b2b31c4

Please sign in to comment.