FS-124 Add generalist agent and update supervisor to call this as a fallback
evpearce committed Dec 17, 2024
1 parent a64e147 commit 3d16b2c
Showing 8 changed files with 138 additions and 189 deletions.
5 changes: 5 additions & 0 deletions backend/src/agents/__init__.py
@@ -11,6 +11,7 @@
 from src.agents.chart_generator_agent import ChartGeneratorAgent
 from src.agents.report_agent import ReportAgent
 from src.agents.materiality_agent import MaterialityAgent
+from src.agents.generalist_agent import GeneralistAgent


 config = Config()
@@ -36,6 +37,10 @@ def get_materiality_agent() -> MaterialityAgent:
     return MaterialityAgent(config.materiality_agent_llm, config.materiality_agent_model)


+def get_generalist_agent() -> GeneralistAgent:
+    return GeneralistAgent(config.intent_agent_llm, config.intent_agent_model)
+
+
 def agent_details(agent: ChatAgent) -> dict:
     return {"name": agent.name, "description": agent.description}
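The new factory mirrors the existing get_*_agent helpers; note that it reuses the intent agent's LLM and model settings rather than introducing its own config keys. A minimal usage sketch, assuming the Config wiring above is initialised:

```python
from src.agents import get_generalist_agent, agent_details

# agent_details is defined in this same module (see the diff above).
agent = get_generalist_agent()
print(agent_details(agent))
# {'name': 'GeneralistAgent',
#  'description': 'This agent attempts to answer a general question using only the LLM'}
```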
39 changes: 39 additions & 0 deletions backend/src/agents/generalist_agent.py
@@ -0,0 +1,39 @@
import logging
from src.prompts import PromptEngine
from src.agents import ChatAgent, chat_agent
from src.utils import Config
from src.utils.web_utils import (
    answer_user_question,
)
import json

logger = logging.getLogger(__name__)
config = Config()

engine = PromptEngine()


@chat_agent(
    name="GeneralistAgent",
    description="This agent attempts to answer a general question using only the LLM",
    tools=[],
)
class GeneralistAgent(ChatAgent):
    async def invoke(self, utterance) -> str:
        try:
            answer_to_user = await answer_user_question(utterance, self.llm, self.model)
            answer_result = json.loads(answer_to_user)
            if answer_result["status"] == "error":
                response = {"content": "Error in finding the answer.", "ignore_validation": "false"}
                return json.dumps(response, indent=4)
            final_answer = json.loads(answer_result["response"]).get("answer", "")
            if not final_answer:
                response = {"content": "Error in answer format.", "ignore_validation": "false"}
                return json.dumps(response, indent=4)
            logger.info(f"Answer found successfully {final_answer}")
            response = {"content": final_answer, "ignore_validation": "false"}
            return json.dumps(response, indent=4)

        except Exception as e:
            logger.error(f"Error in GeneralistAgent.invoke: {e}")
            return "An error occurred while processing the question."
89 changes: 33 additions & 56 deletions backend/src/agents/web_agent.py
@@ -11,7 +11,6 @@
     summarise_pdf_content,
     find_info,
     create_search_term,
-    answer_user_question,
 )
 from .validator_agent import ValidatorAgent
 import aiohttp
@@ -28,59 +27,40 @@

 async def web_general_search_core(search_query, llm, model) -> str:
     try:
-        # Step 1: Generate the search term from the user's query
-        answer_to_user = await answer_user_question(search_query, llm, model)
-        answer_result = json.loads(answer_to_user)
-        if answer_result["status"] == "error":
-            response = {"content": "Error in finding the answer.", "ignore_validation": "false"}
-            return json.dumps(response, indent=4)
-        logger.info(f"Answer found successfully {answer_result}")
-        should_perform_web_search = json.loads(answer_result["response"]).get("should_perform_web_search", "")
-        if not should_perform_web_search:
-            final_answer = json.loads(answer_result["response"]).get("answer", "")
-            if not final_answer:
-                return "No answer found."
-            logger.info(f"Answer found successfully {final_answer}")
-            response = {"content": final_answer, "ignore_validation": "false"}
+        search_term_json = await create_search_term(search_query, llm, model)
+        search_term_result = json.loads(search_term_json)
+
+        # Step 1: Check if there was an error in generating the search term
+        if search_term_result.get("status") == "error":
+            response = {"content": search_term_result.get("error"), "ignore_validation": "false"}
             return json.dumps(response, indent=4)
-        else:
-            search_term_json = await create_search_term(search_query, llm, model)
-            search_term_result = json.loads(search_term_json)
-
-            # Check if there was an error in generating the search term
-            if search_term_result.get("status") == "error":
-                response = {"content": search_term_result.get("error"), "ignore_validation": "false"}
-                return json.dumps(response, indent=4)
-            search_term = json.loads(search_term_result["response"]).get("search_term", "")
-
-            # Step 2: Perform the search using the generated search term
-            search_result = await perform_search(search_term, num_results=15)
-            if search_result.get("status") == "error":
-                return "No relevant information found on the internet for the given query."
-            urls = search_result.get("urls", [])
-            logger.info(f"URLs found: {urls}")
-
-            # Step 3: Scrape content from the URLs found
-            for url in urls:
-                content = await perform_scrape(url)
-                if not content:
-                    continue  # Skip to the next URL if no content is found
-                # logger.info(f"Content scraped successfully: {content}")
-                # Step 4: Summarize the scraped content based on the search term
-                summary = await perform_summarization(search_term, content, llm, model)
-                if not summary:
-                    continue  # Skip if no summary was generated
-
-                # Step 5: Validate the summarization
-                is_valid = await is_valid_answer(summary, search_term)
-                if not is_valid:
-                    continue  # Skip if the summarization is not valid
-                response = {
-                    "content": { "content": summary, "url": url },
-                    "ignore_validation": "false"
-                }
-                return json.dumps(response, indent=4)
+        search_term = json.loads(search_term_result["response"]).get("search_term", "")
+
+        # Step 2: Perform the search using the generated search term
+        search_result = await perform_search(search_term, num_results=15)
+        if search_result.get("status") == "error":
+            return "No relevant information found on the internet for the given query."
+        urls = search_result.get("urls", [])
+        logger.info(f"URLs found: {urls}")
+
+        # Step 3: Scrape content from the URLs found
+        for url in urls:
+            content = await perform_scrape(url)
+            if not content:
+                continue  # Skip to the next URL if no content is found
+            # logger.info(f"Content scraped successfully: {content}")
+            # Step 4: Summarize the scraped content based on the search term
+            summary = await perform_summarization(search_term, content, llm, model)
+            if not summary:
+                continue  # Skip if no summary was generated
+
+            # Step 5: Validate the summarization
+            is_valid = await is_valid_answer(summary, search_term)
+            if not is_valid:
+                continue  # Skip if the summarization is not valid
+            response = {"content": {"content": summary, "url": url}, "ignore_validation": "false"}
+            return json.dumps(response, indent=4)
         return "No relevant information found on the internet for the given query."
     except Exception as e:
         logger.error(f"Error in web_general_search_core: {e}")
         return "An error occurred while processing the search query."
@@ -148,10 +128,7 @@ async def web_scrape_core(url: str) -> str:
             return "No content found at the provided URL."
         logger.info(f"Content scraped successfully: {content}")
         content = content.replace("\n", " ").replace("\r", " ")
-        response = {
-            "content": { "content": content, "url": url },
-            "ignore_validation": "true"
-        }
+        response = {"content": {"content": content, "url": url}, "ignore_validation": "true"}
         return json.dumps(response, indent=4)
     except Exception as e:
         return json.dumps({"status": "error", "error": str(e)})
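The helpers called above (perform_search, perform_scrape, perform_summarization, is_valid_answer) are defined elsewhere in the repository; this diff only shows their call sites. A sketch of the shapes the rewritten core relies on, with signatures inferred from those call sites and bodies deliberately elided:

```python
# Signatures and return shapes inferred from the call sites in
# web_general_search_core; the real implementations live elsewhere.
async def perform_search(search_term: str, num_results: int) -> dict:
    ...  # assumed: {"status": "success", "urls": [...]} or {"status": "error"}

async def perform_scrape(url: str) -> str:
    ...  # assumed: page text, or an empty string when nothing is found

async def perform_summarization(search_term: str, content: str, llm, model) -> str:
    ...  # assumed: a summary string, empty when summarisation fails

async def is_valid_answer(summary: str, search_term: str) -> bool:
    ...  # assumed: True when the summary actually answers the query
```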
45 changes: 1 addition & 44 deletions backend/src/prompts/templates/answer-user-question.j2
@@ -1,47 +1,4 @@
-You are an expert in providing accurate and complete answers to user queries. Your task is twofold:
-
-1. **Generate a detailed answer** to the user's question based on the provided content or context.
-2. **Validate** if the generated answer directly addresses the user's question and is factually accurate.
+You are an expert in providing accurate and complete answers to user queries. Your task is to **generate a detailed answer** to the user's question.
 
 User's question is:
 {{ question }}
-
-Once you generate an answer:
-- **Check** if the answer completely and accurately addresses the user's question.
-- **Determine** if the answer is valid, based on the content provided.
-- **Provide** a reason for the validity or invalidity of the answer.
-- **Indicate** if a more general web search is required then indicate that.
-
-Reply only in JSON format with the following structure:
-
-```json
-{
-    "answer": "The answer to the user's question, based on the content provided",
-    "should_perform_web_search": true or false,
-    "perform_web_search_reason": "A sentence explaining whether the answer is valid or not, and why"
-}
-
-### **Example of Usage:**
-
-#### **User’s Question:**
-- **"What is Tesla's revenue every year since its creation?"**
-
-#### **Content Provided:**
-- A table or a paragraph with data on Tesla's revenue for various years.
-
-#### **LLM’s Response:**
-
-```json
-{
-    "answer": "Tesla's revenue since its creation is: 2008: $15 million, 2009: $30 million, ..., 2023: $81 billion.",
-    "should_perform_web_search": false,
-    "perform_web_search_reason": "The answer includes Tesla's revenue for every year since its creation, based on the data provided. No further search is required."
-}
-
-{
-    "answer": "Tesla's revenue for 2010 to 2023 is available, but data for the earlier years is missing.",
-    "should_perform_web_search": true,
-    "perform_web_search_reason": "The answer is incomplete because data for Tesla's early years is missing. A more general web search is required."
-}
-
-Important: If the question is related to real time data, the LLM should provide should_perform_web_search is true.
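The slimmed-down template now only asks for an answer; the JSON contract and the should_perform_web_search flag are gone from the prompt. A minimal sketch of how the remaining four lines render, using plain Jinja2 as a stand-in for the repo's PromptEngine:

```python
from jinja2 import Template

# Plain-Jinja2 stand-in for engine.load_prompt("answer-user-question", question=...).
template = Template(
    "You are an expert in providing accurate and complete answers to user queries. "
    "Your task is to generate a detailed answer to the user's question.\n"
    "\n"
    "User's question is:\n"
    "{{ question }}"
)
print(template.render(question="What is Tesla's latest annual revenue?"))
```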
10 changes: 5 additions & 5 deletions backend/src/supervisors/supervisor.py
@@ -2,7 +2,7 @@
 import logging
 from src.utils import get_scratchpad, update_scratchpad
 from src.router import get_agent_for_task
-from src.agents import get_validator_agent
+from src.agents import get_validator_agent, get_generalist_agent
 import json

 logger = logging.getLogger(__name__)
@@ -34,10 +34,10 @@ async def solve_all(intent_json) -> None:


 async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
-    if attempt == 5:
-        raise Exception(unsolvable_response)
-
-    agent = await get_agent_for_task(task, scratchpad)
+    if attempt == 3:
+        agent = get_generalist_agent()
+    else:
+        agent = await get_agent_for_task(task, scratchpad)
     if agent is None:
         raise Exception(no_agent_response)
     logger.info(f"Agent selected: {agent.name}")
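Condensed, the new routing gives the LLM-only agent exactly one shot, on the fourth attempt. A hypothetical helper for illustration only (the retry loop that re-enters solve_task with attempt + 1 sits outside this hunk, and whether the deleted unsolvable_response guard survives elsewhere is not visible here):

```python
# Hypothetical condensation of the hunk above, for illustration only.
async def pick_agent(task, scratchpad, attempt):
    if attempt == 3:
        return get_generalist_agent()  # LLM-only fallback on the fourth try
    return await get_agent_for_task(task, scratchpad)  # all other attempts use the router
```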
7 changes: 6 additions & 1 deletion backend/src/utils/web_utils.py
@@ -60,6 +60,7 @@ async def scrape_content(url, limit=100000) -> str:
         }
     )

+
 async def create_search_term(search_query, llm, model) -> str:
     try:
         summariser_prompt = engine.load_prompt("create-search-term", question=search_query)
@@ -81,10 +82,11 @@ async def create_search_term(search_query, llm, model) -> str:
         }
     )

+
 async def answer_user_question(search_query, llm, model) -> str:
     try:
         summariser_prompt = engine.load_prompt("answer-user-question", question=search_query)
-        response = await llm.chat(model, summariser_prompt, "", return_json=True)
+        response = await llm.chat(model, summariser_prompt, "")
         return json.dumps(
             {
                 "status": "success",
@@ -102,6 +104,7 @@ async def answer_user_question(search_query, llm, model) -> str:
         }
     )

+
 async def summarise_content(search_query, contents, llm, model) -> str:
     try:
         summariser_prompt = engine.load_prompt("summariser", question=search_query, content=contents)
@@ -123,6 +126,7 @@ async def summarise_content(search_query, contents, llm, model) -> str:
         }
     )

+
 async def summarise_pdf_content(contents, llm, model) -> str:
     try:
         summariser_prompt = engine.load_prompt("pdf-summariser", content=contents)
@@ -144,6 +148,7 @@ async def summarise_pdf_content(contents, llm, model) -> str:
         }
     )

+
 async def perform_math_operation_util(math_query, llm, model) -> str:
     try:
         math_prompt = engine.load_prompt("math-solver", query=math_query)
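Every helper in web_utils.py wraps the LLM reply in the same status envelope that GeneralistAgent and web_general_search_core parse. The full bodies are collapsed in this view, so here is a sketch of the pattern with the truncated fields reconstructed from how callers read them (the response and error field names are assumptions):

```python
import json

from src.prompts import PromptEngine

engine = PromptEngine()

async def answer_user_question_sketch(search_query, llm, model) -> str:
    # Mirrors the visible pattern: load the prompt, call the LLM, and wrap
    # the raw reply in the envelope that callers json.loads and inspect.
    try:
        prompt = engine.load_prompt("answer-user-question", question=search_query)
        response = await llm.chat(model, prompt, "")
        return json.dumps({"status": "success", "response": response, "error": None})
    except Exception as e:
        return json.dumps({"status": "error", "response": None, "error": str(e)})
```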
45 changes: 45 additions & 0 deletions backend/tests/agents/generalist_agent_test.py
@@ -0,0 +1,45 @@
import pytest
from unittest.mock import patch, AsyncMock
import json
from src.agents.generalist_agent import GeneralistAgent


@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent(
    mock_answer_user_question,
):
    mock_answer_user_question.return_value = json.dumps(
        {"status": "success", "response": json.dumps({"is_valid": True, "answer": "Example summary."})}
    )
    generalist_agent = GeneralistAgent("llm", "mock_model")

    result = await generalist_agent.invoke("example query")
    expected_response = {"content": "Example summary.", "ignore_validation": "false"}
    assert json.loads(result) == expected_response


@pytest.mark.asyncio
async def test_generalist_agent_response_error():
    generalist_agent = GeneralistAgent("llm", "mock_model")

    result = await generalist_agent.invoke("example query")

    expected_response = {"content": "Error in finding the answer.", "ignore_validation": "false"}
    assert json.loads(result) == expected_response


@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent_response_format_error(
    mock_answer_user_question,
):
    mock_answer_user_question.return_value = json.dumps(
        {"status": "success", "response": json.dumps({"is_valid": True, "answer_wrong_format": "Example summary."})}
    )
    generalist_agent = GeneralistAgent("llm", "mock_model")

    result = await generalist_agent.invoke("example query")

    expected_response = {"content": "Error in answer format.", "ignore_validation": "false"}
    assert json.loads(result) == expected_response
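The three tests cover the success path, the upstream error envelope (the unpatched second test appears to rely on the real answer_user_question failing and returning its error envelope), and a malformed answer payload. They run with pytest plus the pytest-asyncio plugin, which the @pytest.mark.asyncio markers assume: pytest backend/tests/agents/generalist_agent_test.py.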
(Diff for the eighth changed file did not load in this view.)