diff --git a/packages/sdk/python/agent_protocol/__init__.py b/packages/sdk/python/agent_protocol/__init__.py new file mode 100644 index 00000000..398e78cb --- /dev/null +++ b/packages/sdk/python/agent_protocol/__init__.py @@ -0,0 +1,18 @@ +from .agent import Agent, StepHandler, TaskHandler, base_router as router +from .models import Artifact, Status, StepRequestBody, TaskRequestBody +from .db import Step, Task, TaskDB + + +__all__ = [ + "Agent", + "Artifact", + "Status", + "Step", + "StepHandler", + "StepRequestBody", + "Task", + "TaskDB", + "TaskHandler", + "TaskRequestBody", + "router", +] diff --git a/packages/sdk/python/agent_protocol/agent.py b/packages/sdk/python/agent_protocol/agent.py new file mode 100644 index 00000000..bf7da5f5 --- /dev/null +++ b/packages/sdk/python/agent_protocol/agent.py @@ -0,0 +1,260 @@ +import asyncio +import os +from uuid import uuid4 + +import aiofiles +from fastapi import APIRouter, UploadFile, Form, File +from fastapi.responses import FileResponse +from hypercorn.asyncio import serve +from hypercorn.config import Config +from typing import Callable, List, Optional, Annotated, Coroutine, Any + +from .db import InMemoryTaskDB, Task, TaskDB, Step +from .server import app +from .models import ( + TaskRequestBody, + StepRequestBody, + Artifact, + Status, + TaskListResponse, + TaskStepsListResponse, + Pagination, +) + + +StepHandler = Callable[[Step], Coroutine[Any, Any, Step]] +TaskHandler = Callable[[Task], Coroutine[Any, Any, None]] + + +_task_handler: Optional[TaskHandler] +_step_handler: Optional[StepHandler] + + +base_router = APIRouter() + + +@base_router.post("/ap/v1/agent/tasks", response_model=Task, tags=["agent"]) +async def create_agent_task(body: TaskRequestBody | None = None) -> Task: + """ + Creates a task for the agent. + """ + if not _task_handler: + raise Exception("Task handler not defined") + + task = await Agent.db.create_task( + input=body.input if body else None, + additional_input=body.additional_input if body else None, + ) + await _task_handler(task) + + return task + + +@base_router.get("/ap/v1/agent/tasks", response_model=TaskListResponse, tags=["agent"]) +async def list_agent_tasks_ids(page_size: int = 10, current_page: int = 1) -> List[str]: + """ + List all tasks that have been created for the agent. + """ + tasks = await Agent.db.list_tasks() + start_index = (current_page - 1) * page_size + end_index = start_index + page_size + return TaskListResponse( + tasks=tasks[start_index:end_index], + pagination=Pagination( + total_items=len(tasks), + total_pages=len(tasks) // page_size, + current_page=current_page, + page_size=page_size, + ), + ) + + +@base_router.get("/ap/v1/agent/tasks/{task_id}", response_model=Task, tags=["agent"]) +async def get_agent_task(task_id: str) -> Task: + """ + Get details about a specified agent task. + """ + return await Agent.db.get_task(task_id) + + +@base_router.get( + "/ap/v1/agent/tasks/{task_id}/steps", + response_model=TaskStepsListResponse, + tags=["agent"], +) +async def list_agent_task_steps( + task_id: str, page_size: int = 10, current_page: int = 1 +) -> List[str]: + """ + List all steps for the specified task. + """ + task = await Agent.db.get_task(task_id) + start_index = (current_page - 1) * page_size + end_index = start_index + page_size + return TaskStepsListResponse( + steps=task.steps[start_index:end_index], + pagination=Pagination( + total_items=len(task.steps), + total_pages=len(task.steps) // page_size, + current_page=current_page, + page_size=page_size, + ), + ) + + +@base_router.post( + "/ap/v1/agent/tasks/{task_id}/steps", + response_model=Step, + tags=["agent"], +) +async def execute_agent_task_step( + task_id: str, + body: StepRequestBody | None = None, +) -> Step: + """ + Execute a step in the specified agent task. + """ + if not _step_handler: + raise Exception("Step handler not defined") + + task = await Agent.db.get_task(task_id) + step = next(filter(lambda x: x.status == Status.created, task.steps), None) + + if not step: + raise Exception("No steps to execute") + + step.status = Status.running + + step.input = body.input if body else None + step.additional_input = body.additional_input if body else None + + step = await _step_handler(step) + + step.status = Status.completed + return step + + +@base_router.get( + "/ap/v1/agent/tasks/{task_id}/steps/{step_id}", + response_model=Step, + tags=["agent"], +) +async def get_agent_task_step(task_id: str, step_id: str) -> Step: + """ + Get details about a specified task step. + """ + return await Agent.db.get_step(task_id, step_id) + + +@base_router.get( + "/ap/v1/agent/tasks/{task_id}/artifacts", + response_model=List[Artifact], + tags=["agent"], +) +async def list_agent_task_artifacts(task_id: str) -> List[Artifact]: + """ + List all artifacts for the specified task. + """ + task = await Agent.db.get_task(task_id) + return task.artifacts + + +@base_router.post( + "/ap/v1/agent/tasks/{task_id}/artifacts", + response_model=Artifact, + tags=["agent"], +) +async def upload_agent_task_artifacts( + task_id: str, + file: Annotated[UploadFile, File()], + relative_path: Annotated[Optional[str], Form()] = None, +) -> Artifact: + """ + Upload an artifact for the specified task. + """ + file_name = file.filename or str(uuid4()) + await Agent.db.get_task(task_id) + artifact = await Agent.db.create_artifact( + task_id=task_id, + agent_created=False, + file_name=file_name, + relative_path=relative_path, + ) + + path = Agent.get_artifact_folder(task_id, artifact) + if not os.path.exists(path): + os.makedirs(path) + + async with aiofiles.open(os.path.join(path, file_name), "wb") as f: + while content := await file.read(1024 * 1024): # async read chunk ~1MiB + await f.write(content) + + return artifact + + +@base_router.get( + "/ap/v1/agent/tasks/{task_id}/artifacts/{artifact_id}", + tags=["agent"], +) +async def download_agent_task_artifacts(task_id: str, artifact_id: str) -> FileResponse: + """ + Download the specified artifact. + """ + artifact = await Agent.db.get_artifact(task_id, artifact_id) + path = Agent.get_artifact_path(task_id, artifact) + return FileResponse( + path=path, media_type="application/octet-stream", filename=artifact.file_name + ) + + +class Agent: + db: TaskDB = InMemoryTaskDB() + workspace: str = os.getenv("AGENT_WORKSPACE", "workspace") + + @staticmethod + def setup_agent(task_handler: TaskHandler, step_handler: StepHandler): + """ + Set the agent's task and step handlers. + """ + global _task_handler + _task_handler = task_handler + + global _step_handler + _step_handler = step_handler + + return Agent + + @staticmethod + def get_workspace(task_id: str) -> str: + """ + Get the workspace path for the specified task. + """ + return os.path.join(os.getcwd(), Agent.workspace, task_id) + + @staticmethod + def get_artifact_folder(task_id: str, artifact: Artifact) -> str: + """ + Get the artifact path for the specified task and artifact. + """ + workspace_path = Agent.get_workspace(task_id) + relative_path = artifact.relative_path or "" + return os.path.join(workspace_path, relative_path) + + @staticmethod + def get_artifact_path(task_id: str, artifact: Artifact) -> str: + """ + Get the artifact path for the specified task and artifact. + """ + return os.path.join( + Agent.get_artifact_folder(task_id, artifact), artifact.file_name + ) + + @staticmethod + def start(port: int = 8000, router: APIRouter = base_router): + """ + Start the agent server. + """ + config = Config() + config.bind = [f"localhost:{port}"] # As an example configuration setting + app.include_router(router) + asyncio.run(serve(app, config)) diff --git a/packages/sdk/python/agent_protocol/cli.py b/packages/sdk/python/agent_protocol/cli.py new file mode 100644 index 00000000..bddf4210 --- /dev/null +++ b/packages/sdk/python/agent_protocol/cli.py @@ -0,0 +1,25 @@ +import click + +from agent_protocol.utils.compliance import check_compliance + + +@click.group() +def cli(): + pass + + +@cli.command( + "test", + context_settings=dict( + ignore_unknown_options=True, + ), +) +@click.option("--url", "-u", type=str, required=True, help="URL of the Agent API") +@click.argument("args", nargs=-1, type=click.UNPROCESSED) +def _check_compliance(url: str, args: list): + """ + This script checks if the Agent API is Agent Protocol compliant. + + In the background it runs pytest, you can pass additional arguments to pytest. + """ + check_compliance(url, args) diff --git a/packages/sdk/python/agent_protocol/db.py b/packages/sdk/python/agent_protocol/db.py new file mode 100644 index 00000000..50bc3474 --- /dev/null +++ b/packages/sdk/python/agent_protocol/db.py @@ -0,0 +1,182 @@ +import uuid +from abc import ABC +from typing import Dict, List, Optional, Any +from .models import Task as APITask, Step as APIStep, Artifact, Status, NotFoundResponse + + +class Step(APIStep): + additional_properties: Optional[Dict[str, str]] = None + + +class Task(APITask): + steps: List[Step] = [] + + +class NotFoundException(Exception): + """ + Exception raised when a resource is not found. + """ + + def __init__(self, item_name: str, item_id: str): + self.item_name = item_name + self.item_id = item_id + super().__init__(NotFoundResponse( + message=f"{item_name} with {item_id} not found." + )) + + +class TaskDB(ABC): + async def create_task( + self, + input: Optional[str], + additional_input: Any = None, + artifacts: Optional[List[Artifact]] = None, + steps: Optional[List[Step]] = None, + ) -> Task: + raise NotImplementedError + + async def create_step( + self, + task_id: str, + name: Optional[str] = None, + input: Optional[str] = None, + is_last: bool = False, + additional_properties: Optional[Dict[str, str]] = None, + artifacts: List[Artifact] = [], + ) -> Step: + raise NotImplementedError + + async def create_artifact( + self, + task_id: str, + file_name: str, + agent_created: bool = True, + relative_path: Optional[str] = None, + step_id: Optional[str] = None, + ) -> Artifact: + raise NotImplementedError + + async def get_task(self, task_id: str) -> Task: + raise NotImplementedError + + async def get_step(self, task_id: str, step_id: str) -> Step: + raise NotImplementedError + + async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: + raise NotImplementedError + + async def list_tasks(self) -> List[Task]: + raise NotImplementedError + + async def list_steps( + self, task_id: str, status: Optional[Status] = None + ) -> List[Step]: + raise NotImplementedError + + +class InMemoryTaskDB(TaskDB): + _tasks: Dict[str, Task] = {} + + async def create_task( + self, + input: Optional[str], + additional_input: Any = None, + artifacts: Optional[List[Artifact]] = None, + steps: Optional[List[Step]] = None, + ) -> Task: + if not steps: + steps = [] + if not artifacts: + artifacts = [] + task_id = str(uuid.uuid4()) + task = Task( + task_id=task_id, + input=input, + steps=steps, + artifacts=artifacts, + additional_input=additional_input, + ) + self._tasks[task_id] = task + return task + + async def create_step( + self, + task_id: str, + name: Optional[str] = None, + input: Optional[str] = None, + is_last=False, + additional_properties: Optional[Dict[str, Any]] = None, + artifacts: List[Artifact] = [], + ) -> Step: + step_id = str(uuid.uuid4()) + step = Step( + task_id=task_id, + step_id=step_id, + name=name, + input=input, + status=Status.created, + is_last=is_last, + additional_properties=additional_properties, + artifacts=artifacts, + ) + task = await self.get_task(task_id) + task.steps.append(step) + return step + + async def get_task(self, task_id: str) -> Task: + task = self._tasks.get(task_id, None) + if not task: + raise NotFoundException("Task", task_id) + return task + + async def get_step(self, task_id: str, step_id: str) -> Step: + task = await self.get_task(task_id) + step = next(filter(lambda s: s.task_id == task_id, task.steps), None) + if not step: + raise NotFoundException("Step", step_id) + return step + + async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: + task = await self.get_task(task_id) + artifact = next( + filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None + ) + if not artifact: + raise NotFoundException("Artifact", artifact_id) + return artifact + + async def create_artifact( + self, + task_id: str, + file_name: str, + agent_created: bool = True, + relative_path: Optional[str] = None, + step_id: Optional[str] = None, + ) -> Artifact: + artifact_id = str(uuid.uuid4()) + artifact = Artifact( + artifact_id=artifact_id, + agent_created=agent_created, + file_name=file_name, + relative_path=relative_path + ) + task = await self.get_task(task_id) + task.artifacts.append(artifact) + + if step_id: + step = await self.get_step(task_id, step_id) + step.artifacts.append(artifact) + + return artifact + + async def list_tasks(self) -> List[Task]: + return [task for task in self._tasks.values()] + + async def list_steps( + self, task_id: str, status: Optional[Status] = None + ) -> List[Step]: + task = await self.get_task(task_id) + steps = task.steps + if status: + steps = list(filter(lambda s: s.status == status, steps)) + return steps diff --git a/packages/sdk/python/agent_protocol/middlewares.py b/packages/sdk/python/agent_protocol/middlewares.py new file mode 100644 index 00000000..e9999392 --- /dev/null +++ b/packages/sdk/python/agent_protocol/middlewares.py @@ -0,0 +1,11 @@ +from fastapi import Request +from fastapi.responses import JSONResponse +from agent_protocol.db import NotFoundException + +async def not_found_exception_handler( + request: Request, exc: NotFoundException +) -> JSONResponse: + return JSONResponse( + content={"message": f"{exc.item_name} with {exc.item_id} not found."}, + status_code=404, + ) \ No newline at end of file diff --git a/packages/sdk/python/agent_protocol/models.py b/packages/sdk/python/agent_protocol/models.py new file mode 100644 index 00000000..6e3648a0 --- /dev/null +++ b/packages/sdk/python/agent_protocol/models.py @@ -0,0 +1,219 @@ +# generated by fastapi-codegen: +# filename: openapi.yml +# timestamp: 2023-09-16T00:59:36+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class NotFoundResponse(BaseModel): + message: str = Field( + ..., + description="Message stating the entity was not found", + example="Unable to find entity with the provided id", + ) + + +class Pagination(BaseModel): + total_items: int = Field(..., description="Total number of items.", example=42) + total_pages: int = Field(..., description="Total number of pages.", example=97) + current_page: int = Field(..., description="Current_page page number.", example=1) + page_size: int = Field(..., description="Number of items per page.", example=25) + + +class TaskInput(BaseModel): + pass + + +class Artifact(BaseModel): + artifact_id: str = Field( + ..., + description="ID of the artifact.", + example="b225e278-8b4c-4f99-a696-8facf19f0e56", + ) + agent_created: bool = Field( + ..., + description="Whether the artifact has been created by the agent.", + example=False, + ) + file_name: str = Field( + ..., description="Filename of the artifact.", example="main.py" + ) + relative_path: Optional[str] = Field( + None, + description="Relative path of the artifact in the agent's workspace.", + example="python/code/", + ) + + +class ArtifactUpload(BaseModel): + file: bytes = Field( + ..., description="File to upload.", example="binary representation of file" + ) + relative_path: Optional[str] = Field( + None, + description="Relative path of the artifact in the agent's workspace.", + example="python/code", + ) + + +class StepInput(BaseModel): + pass + + +class StepOutput(BaseModel): + pass + + +class TaskRequestBody(BaseModel): + input: Optional[str] = Field( + None, + description="Input prompt for the task.", + example="Write 'Washington' to the file 'output.txt'.", + ) + additional_input: Optional[TaskInput] = None + + +class Task(TaskRequestBody): + task_id: str = Field( + ..., + description="The ID of the task.", + example="50da533e-3904-4401-8a07-c49adf88b5eb", + ) + artifacts: List[Artifact] = Field( + ..., + description="A list of artifacts that the task has produced.", + example=[ + "7a49f31c-f9c6-4346-a22c-e32bc5af4d8e", + "ab7b4091-2560-4692-a4fe-d831ea3ca7d6", + ], + ) + + +class StepRequestBody(BaseModel): + input: Optional[str] = Field( + None, + description="Input prompt for the step.", + example="Write the words you receive to the file 'output.txt'.", + ) + additional_input: Optional[StepInput] = None + + +class Status(Enum): + created = "created" + running = "running" + completed = "completed" + + +class Step(StepRequestBody): + task_id: str = Field( + ..., + description="The ID of the task this step belongs to.", + example="50da533e-3904-4401-8a07-c49adf88b5eb", + ) + step_id: str = Field( + ..., + description="The ID of the task step.", + example="6bb1801a-fd80-45e8-899a-4dd723cc602e", + ) + name: Optional[str] = Field( + None, description="The name of the task step.", example="Write to file" + ) + status: Status = Field( + ..., description="The status of the task step.", example="created" + ) + output: Optional[str] = Field( + None, + description="Output of the task step.", + example="I am going to use the write_to_file command and write Washington to a file called output.txt dict: + return TaskRequestBody(input="test").dict() + + def test_create_agent_task(self, url): + response = requests.post(f"{url}/ap/v1/agent/tasks", json=self.task_data) + assert response.status_code == 200 + assert Task(**response.json()).task_id + + def test_list_agent_tasks_ids(self, url): + response = requests.get(f"{url}/ap/v1/agent/tasks") + assert response.status_code == 200 + assert isinstance(response.json(), list) + + def test_get_agent_task(self, url): + # Create task + response = requests.post(f"{url}/ap/v1/agent/tasks", json=self.task_data) + task_id = response.json()["task_id"] + response = requests.get(f"{url}/ap/v1/agent/tasks/{task_id}") + assert response.status_code == 200 + assert Task(**response.json()).task_id == task_id + + def test_list_agent_task_steps(self, url): + # Create task + response = requests.post(f"{url}/ap/v1/agent/tasks", json=self.task_data) + task_id = response.json()["task_id"] + response = requests.get(f"{url}/ap/v1/agent/tasks/{task_id}/steps") + assert response.status_code == 200 + assert isinstance(response.json(), list) + + def test_execute_agent_task_step(self, url): + # Create task + response = requests.post(f"{url}/ap/v1/agent/tasks", json=self.task_data) + task_id = response.json()["task_id"] + step_body = StepRequestBody(input="test") + response = requests.post( + f"{url}/ap/v1/agent/tasks/{task_id}/steps", json=step_body.dict() + ) + assert response.status_code == 200 + + def test_list_artifacts(self, url): + response = requests.post(f"{url}/ap/v1/agent/tasks", json=self.task_data) + task_id = response.json()["task_id"] + response = requests.get(f"{url}/ap/v1/agent/tasks/{task_id}/artifacts") + assert response.status_code == 200 + assert isinstance(response.json(), list) + + def test_get_agent_task_step(self, url): + # Create task + response = requests.post(f"{url}/ap/v1/agent/tasks", json=self.task_data) + task_id = response.json()["task_id"] + # Get steps + response = requests.get(f"{url}/ap/v1/agent/tasks/{task_id}/steps") + step_id = response.json()[0] + response = requests.get(f"{url}/ap/v1/agent/tasks/{task_id}/steps/{step_id}") + assert response.status_code == 200 + assert Step(**response.json()).step_id == step_id + + +def provide_url_scheme(url: str, default_scheme: str = "https") -> str: + """Make sure we have valid url scheme. + Params: + url : string : the URL + default_scheme : string : default scheme to use, e.g. 'https' + Returns: + string : updated url with validated/attached scheme + """ + if not url: + return url + + if "localhost" in url or "127.0.0.1" in url: + default_scheme = "http" + + has_scheme = ":" in url[:7] + is_universal_scheme = url.startswith("//") + is_file_path = url == "-" or (url.startswith("/") and not is_universal_scheme) + if has_scheme or is_file_path: + return url + if is_universal_scheme: + return default_scheme + ":" + url + return default_scheme + "://" + url + + +def check_compliance(url, additional_pytest_args): + url = provide_url_scheme(url) + exit_code = pytest.main( + [ + "-v", + __file__, + "--url", + url, + "-W", + "ignore:Module already imported:pytest.PytestWarning", + ] + + list(additional_pytest_args) + ) + assert exit_code == 0, "Your Agent API isn't compliant with the agent protocol"