Skip to content

Commit

Permalink
Add RouterOrchestrator (#166)
Browse files Browse the repository at this point in the history
* change to break

* add logging

* wip

* wip

* wip

* add router orchestrator

* wip

* wip

* wip

* don't delete

* remove logger from pipeline orchestrator

* module imports

* cr

* better naming
  • Loading branch information
nerdai authored Aug 6, 2024
1 parent 41ae402 commit 93f19a2
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
2 changes: 2 additions & 0 deletions llama_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from llama_agents.orchestrators import (
AgentOrchestrator,
PipelineOrchestrator,
OrchestratorRouter,
)
from llama_agents.tools import (
AgentServiceTool,
Expand Down Expand Up @@ -59,6 +60,7 @@
# orchestrators
"AgentOrchestrator",
"PipelineOrchestrator",
"OrchestratorRouter",
# various utils
"AgentServiceTool",
"ServiceAsTool",
Expand Down
2 changes: 2 additions & 0 deletions llama_agents/orchestrators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from llama_agents.orchestrators.agent import AgentOrchestrator
from llama_agents.orchestrators.base import BaseOrchestrator
from llama_agents.orchestrators.pipeline import PipelineOrchestrator
from llama_agents.orchestrators.orchestrator_router import OrchestratorRouter

__all__ = [
"BaseOrchestrator",
"PipelineOrchestrator",
"AgentOrchestrator",
"OrchestratorRouter",
]
109 changes: 109 additions & 0 deletions llama_agents/orchestrators/orchestrator_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Any, Dict, List, Tuple

from llama_index.core.tools import BaseTool
from llama_index.core.base.base_selector import BaseSelector

from llama_agents.messages.base import QueueMessage
from llama_agents.orchestrators.base import BaseOrchestrator
from llama_agents.types import TaskDefinition, TaskResult

import logging

logger = logging.getLogger(__name__)


class OrchestratorRouter(BaseOrchestrator):
"""Orchestrator that routes between a list of orchestrators.
Given an incoming task, first select the most relevant orchestrator to the
task, and then use that orchestrator to process it.
Attributes:
orchestrators (List[BaseOrchestrator]): The orchestrators to choose from. (must correspond to choices)
choices (List[str]): The descriptions of the orchestrators (must correspond to components)
selector (BaseSelector): The orchestrator selector.
Examples:
```python
from llama_index.core.query_pipeline import QueryPipeline
from llama_agents import (
PipelineOrchestrator,
RouterOrchestrator,
AgentService,
ServiceComponent
)
query_rewrite_server = AgentService(
agent=hyde_agent,
message_queue=message_queue,
description="Used to rewrite queries",
service_name="query_rewrite_agent",
host="127.0.0.1",
port=8011,
)
query_rewrite_server_c = ServiceComponent.from_service_definition(query_rewrite_server)
rag_agent_server = AgentService(
agent=rag_agent,
message_queue=message_queue,
description="rag_agent",
host="127.0.0.1",
port=8012,
)
rag_agent_server_c = ServiceComponent.from_service_definition(rag_agent_server)
# create our multi-agent framework components
pipeline_1 = QueryPipeline(chain=[query_rewrite_server_c])
orchestrator_1 = PipelineOrchestrator(pipeline=pipeline_1)
pipeline_2 = QueryPipeline(chain=[rag_agent_server_c])
orchestrator_2 = PipelineOrchestrator(pipeline=pipeline_2)
orchestrator = RouterOrchestrator(
selector=PydanticSingleSelector.from_defaults(llm=OpenAI()),
orchestrators=[orchestrator_1, orchestrator_2],
choices=["description of orchestrator_1", "description of orchestrator_2"],
)
"""

def __init__(
self,
orchestrators: List[BaseOrchestrator],
choices: List[str],
selector: BaseSelector,
):
self.orchestrators = orchestrators
self.choices = choices
self.selector = selector
self.tasks: Dict[str, int] = {}

async def _select_orchestrator(self, task_def: TaskDefinition) -> BaseOrchestrator:
if task_def.task_id not in self.tasks:
sel_output = await self.selector.aselect(self.choices, task_def.input)
self.tasks[task_def.task_id] = sel_output.ind
# assume one selection
if len(sel_output.selections) != 1:
raise ValueError("Expected one selection")
logger.info("Selected orchestrator for task.")
return self.orchestrators[self.tasks[task_def.task_id]]

async def get_next_messages(
self, task_def: TaskDefinition, tools: List[BaseTool], state: Dict[str, Any]
) -> Tuple[List[QueueMessage], Dict[str, Any]]:
"""Get the next message to process. Returns the message and the new state."""
orchestrator = await self._select_orchestrator(task_def)
return await orchestrator.get_next_messages(task_def, tools, state)

async def add_result_to_state(
self, result: TaskResult, state: Dict[str, Any]
) -> Dict[str, Any]:
"""Add the result of processing a message to the state. Returns the new state.
TODO: figure out a way to properly clear the tasks dictionary when the
highest level Task is actually completed.
"""
if result.task_id not in self.tasks:
raise ValueError("Task not found.")
orchestrator = self.orchestrators[self.tasks[result.task_id]]
res = await orchestrator.add_result_to_state(result, state)
return res

0 comments on commit 93f19a2

Please sign in to comment.