Skip to content

Commit

Permalink
FS-124 Add generalist agent and update supervisor to call this as a f…
Browse files Browse the repository at this point in the history
…allback (#55)
  • Loading branch information
evpearce authored Dec 19, 2024
1 parent a64e147 commit 72d8eaa
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 235 deletions.
5 changes: 5 additions & 0 deletions backend/src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from src.agents.chart_generator_agent import ChartGeneratorAgent
from src.agents.report_agent import ReportAgent
from src.agents.materiality_agent import MaterialityAgent
from src.agents.generalist_agent import GeneralistAgent


config = Config()
Expand All @@ -36,6 +37,10 @@ def get_materiality_agent() -> MaterialityAgent:
return MaterialityAgent(config.materiality_agent_llm, config.materiality_agent_model)


def get_generalist_agent() -> GeneralistAgent:
    # Factory for the LLM-only generalist agent (no tools).
    # NOTE(review): reuses the *intent* agent's llm/model config — confirm the
    # generalist agent is meant to share those settings rather than get its own.
    return GeneralistAgent(config.intent_agent_llm, config.intent_agent_model)


def agent_details(agent: ChatAgent) -> dict:
return {"name": agent.name, "description": agent.description}

Expand Down
55 changes: 55 additions & 0 deletions backend/src/agents/generalist_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import logging
from src.prompts import PromptEngine
from src.agents import ChatAgent, chat_agent
from src.utils import Config
import json

logger = logging.getLogger(__name__)
config = Config()

engine = PromptEngine()


@chat_agent(
    name="GeneralistAgent",
    description="This agent attempts to answer a general question using only the llm",
    tools=[],
)
class GeneralistAgent(ChatAgent):
    async def invoke(self, utterance: str) -> str:
        """Answer a general user question directly with the LLM (no tools).

        Args:
            utterance: The user's question.

        Returns:
            A JSON string of the form
            {"content": <answer or error message>, "ignore_validation": "false"}.
            Every path (success, bad format, exception) returns valid JSON so
            callers that json.loads() the reply never crash.
        """
        try:
            answer_to_user = await answer_user_question(utterance, self.llm, self.model)
            answer_result = json.loads(answer_to_user)
            # answer_user_question wraps the raw LLM reply under "response";
            # the reply itself is expected to be JSON carrying an "answer" key.
            final_answer = json.loads(answer_result["response"]).get("answer", "")
            if not final_answer:
                response = {"content": "Error in answer format.", "ignore_validation": "false"}
                return json.dumps(response, indent=4)
            logger.info(f"Answer found successfully {final_answer}")
            response = {"content": final_answer, "ignore_validation": "false"}
            return json.dumps(response, indent=4)

        except Exception as e:
            # Fixed: message previously said "web_general_search_core" (copy-paste
            # from the web agent) and the bare-string return was not valid JSON.
            logger.error(f"Error in GeneralistAgent invoke: {e}")
            response = {
                "content": "An error occurred while processing the question.",
                "ignore_validation": "false",
            }
            return json.dumps(response, indent=4)


async def answer_user_question(search_query, llm, model) -> str:
    """Ask the LLM to answer *search_query* via the "generalist-answer" prompt.

    Args:
        search_query: The user's question, injected into the prompt template.
        llm: LLM client exposing an async ``chat(model, prompt, context)``.
        model: Model identifier passed through to ``llm.chat``.

    Returns:
        A JSON string: {"status": "success"|"error",
        "response": <raw llm reply or None>, "error": <message or None>}.
        Never raises — failures are reported in the JSON envelope.
    """
    try:
        summariser_prompt = engine.load_prompt("generalist-answer", question=search_query)
        response = await llm.chat(model, summariser_prompt, "")
        return json.dumps(
            {
                "status": "success",
                "response": response,
                "error": None,
            }
        )
    except Exception as e:
        # Fixed: log message previously said "create search term" (copy-paste
        # from web_utils.create_search_term).
        logger.error(f"Error answering user question: {e}")
        return json.dumps(
            {
                "status": "error",
                "response": None,
                "error": str(e),
            }
        )
99 changes: 32 additions & 67 deletions backend/src/agents/web_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
summarise_pdf_content,
find_info,
create_search_term,
answer_user_question,
)
from .validator_agent import ValidatorAgent
import aiohttp
import io
from pypdf import PdfReader
import json
from typing import Dict, Any
from typing import Any

logger = logging.getLogger(__name__)
config = Config()
Expand All @@ -28,59 +27,37 @@

async def web_general_search_core(search_query, llm, model) -> str:
try:
# Step 1: Generate the search term from the user's query
answer_to_user = await answer_user_question(search_query, llm, model)
answer_result = json.loads(answer_to_user)
if answer_result["status"] == "error":
response = {"content": "Error in finding the answer.", "ignore_validation": "false"}
return json.dumps(response, indent=4)
logger.info(f"Answer found successfully {answer_result}")
should_perform_web_search = json.loads(answer_result["response"]).get("should_perform_web_search", "")
if not should_perform_web_search:
final_answer = json.loads(answer_result["response"]).get("answer", "")
if not final_answer:
return "No answer found."
logger.info(f"Answer found successfully {final_answer}")
response = {"content": final_answer, "ignore_validation": "false"}
return json.dumps(response, indent=4)
else:
search_term_json = await create_search_term(search_query, llm, model)
search_term_result = json.loads(search_term_json)

# Check if there was an error in generating the search term
if search_term_result.get("status") == "error":
response = {"content": search_term_result.get("error"), "ignore_validation": "false"}
return json.dumps(response, indent=4)
search_term = json.loads(search_term_result["response"]).get("search_term", "")

# Step 2: Perform the search using the generated search term
search_result = await perform_search(search_term, num_results=15)
if search_result.get("status") == "error":
return "No relevant information found on the internet for the given query."
urls = search_result.get("urls", [])
logger.info(f"URLs found: {urls}")

# Step 3: Scrape content from the URLs found
for url in urls:
content = await perform_scrape(url)
if not content:
continue # Skip to the next URL if no content is found
# logger.info(f"Content scraped successfully: {content}")
# Step 4: Summarize the scraped content based on the search term
summary = await perform_summarization(search_term, content, llm, model)
if not summary:
continue # Skip if no summary was generated

# Step 5: Validate the summarization
is_valid = await is_valid_answer(summary, search_term)
if not is_valid:
continue # Skip if the summarization is not valid
response = {
"content": { "content": summary, "url": url },
"ignore_validation": "false"
}
return json.dumps(response, indent=4)
search_term_json = await create_search_term(search_query, llm, model)
search_term_result = json.loads(search_term_json)
search_term = json.loads(search_term_result["response"]).get("search_term", "")

# Step 1: Perform the search using the generated search term
search_result_json = await search_urls(search_query, num_results=15)
search_result = json.loads(search_result_json)

if search_result.get("status") == "error":
return "No relevant information found on the internet for the given query."
urls = search_result.get("urls", [])
logger.info(f"URLs found: {urls}")

# Step 2: Scrape content from the URLs found
for url in urls:
content = await perform_scrape(url)
if not content:
continue # Skip to the next URL if no content is found
logger.info(f"Content scraped successfully: {content}")
# Step 3: Summarize the scraped content based on the search term
summary = await perform_summarization(search_term, content, llm, model)
if not summary:
continue # Skip if no summary was generated

# Step 4: Validate the summarization
is_valid = await is_valid_answer(summary, search_term)
if not is_valid:
continue # Skip if the summarization is not valid
response = {"content": {"content": summary, "url": url}, "ignore_validation": "false"}
return json.dumps(response, indent=4)
return "No relevant information found on the internet for the given query."
except Exception as e:
logger.error(f"Error in web_general_search_core: {e}")
return "An error occurred while processing the search query."
Expand Down Expand Up @@ -148,10 +125,7 @@ async def web_scrape_core(url: str) -> str:
return "No content found at the provided URL."
logger.info(f"Content scraped successfully: {content}")
content = content.replace("\n", " ").replace("\r", " ")
response = {
"content": { "content": content, "url": url },
"ignore_validation": "true"
}
response = {"content": {"content": content, "url": url}, "ignore_validation": "true"}
return json.dumps(response, indent=4)
except Exception as e:
return json.dumps({"status": "error", "error": str(e)})
Expand Down Expand Up @@ -216,15 +190,6 @@ async def is_valid_answer(answer, task) -> bool:
return is_valid


async def perform_search(search_query: str, num_results: int) -> Dict[str, Any]:
try:
search_result_json = await search_urls(search_query, num_results=num_results)
return json.loads(search_result_json)
except Exception as e:
logger.error(f"Error during web search: {e}")
return {"status": "error", "urls": []}


async def perform_scrape(url: str) -> str:
try:
if not str(url).startswith("https"):
Expand Down
47 changes: 0 additions & 47 deletions backend/src/prompts/templates/answer-user-question.j2

This file was deleted.

4 changes: 4 additions & 0 deletions backend/src/prompts/templates/generalist-answer.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
You are an expert in providing accurate and complete answers to user queries regarding ESG. Your task is to **Generate a detailed answer** to the user's question.

User's question is:
{{ question }}
35 changes: 18 additions & 17 deletions backend/src/supervisors/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from src.utils import get_scratchpad, update_scratchpad
from src.router import get_agent_for_task
from src.agents import get_validator_agent
from src.agents import get_validator_agent, get_generalist_agent
import json

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -34,22 +34,23 @@ async def solve_all(intent_json) -> None:


async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
if attempt == 5:
raise Exception(unsolvable_response)

agent = await get_agent_for_task(task, scratchpad)
if agent is None:
raise Exception(no_agent_response)
logger.info(f"Agent selected: {agent.name}")
logger.info(f"Task is {task}")
answer = await agent.invoke(task)
parsed_json = json.loads(answer)
status = parsed_json.get("status", "success")
ignore_validation = parsed_json.get("ignore_validation", "")
answer_content = parsed_json.get("content", "")
if (ignore_validation == "true") or await is_valid_answer(answer_content, task):
return (agent.name, answer_content, status)
return await solve_task(task, scratchpad, attempt + 1)
for attempt in [1, 2, 3, 4]:
if attempt == 4:
agent = get_generalist_agent()
else:
agent = await get_agent_for_task(task, scratchpad)
if agent is None:
raise Exception(no_agent_response)
logger.info(f"Agent selected: {agent.name}")
logger.info(f"Task is {task}")
answer = await agent.invoke(task)
parsed_json = json.loads(answer)
status = parsed_json.get("status", "success")
ignore_validation = parsed_json.get("ignore_validation", "")
answer_content = parsed_json.get("content", "")
if (ignore_validation == "true") or await is_valid_answer(answer_content, task):
return (agent.name, answer_content, status)
raise Exception(unsolvable_response)


async def is_valid_answer(answer, task) -> bool:
Expand Down
23 changes: 3 additions & 20 deletions backend/src/utils/web_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def scrape_content(url, limit=100000) -> str:
}
)


async def create_search_term(search_query, llm, model) -> str:
try:
summariser_prompt = engine.load_prompt("create-search-term", question=search_query)
Expand All @@ -81,26 +82,6 @@ async def create_search_term(search_query, llm, model) -> str:
}
)

async def answer_user_question(search_query, llm, model) -> str:
try:
summariser_prompt = engine.load_prompt("answer-user-question", question=search_query)
response = await llm.chat(model, summariser_prompt, "", return_json=True)
return json.dumps(
{
"status": "success",
"response": response,
"error": None,
}
)
except Exception as e:
logger.error(f"Error during create search term: {e}")
return json.dumps(
{
"status": "error",
"response": None,
"error": str(e),
}
)

async def summarise_content(search_query, contents, llm, model) -> str:
try:
Expand All @@ -123,6 +104,7 @@ async def summarise_content(search_query, contents, llm, model) -> str:
}
)


async def summarise_pdf_content(contents, llm, model) -> str:
try:
summariser_prompt = engine.load_prompt("pdf-summariser", content=contents)
Expand All @@ -144,6 +126,7 @@ async def summarise_pdf_content(contents, llm, model) -> str:
}
)


async def perform_math_operation_util(math_query, llm, model) -> str:
try:
math_prompt = engine.load_prompt("math-solver", query=math_query)
Expand Down
35 changes: 35 additions & 0 deletions backend/tests/agents/generalist_agent_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
from unittest.mock import patch, AsyncMock
import json
from src.agents.generalist_agent import GeneralistAgent


@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent(
    mock_answer_user_question,
):
    # Happy path: the mocked helper returns a success envelope whose "response"
    # field is itself JSON containing an "answer" key.
    mock_answer_user_question.return_value = json.dumps(
        {"status": "success", "response": json.dumps({"is_valid": True, "answer": "Example summary."})}
    )
    generalist_agent = GeneralistAgent("llm", "mock_model")

    result = await generalist_agent.invoke("example query")
    expected_response = {"content": "Example summary.", "ignore_validation": "false"}
    # invoke() should surface the extracted answer verbatim as the content.
    assert json.loads(result) == expected_response


@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent_response_format_error(
    mock_answer_user_question,
):
    # Fixed: function name previously had a typo ("reponse").
    # Payload lacks the "answer" key ("answer_wrong_format" stands in for any
    # malformed reply), so invoke() should return the format-error response.
    mock_answer_user_question.return_value = json.dumps(
        {"status": "success", "response": json.dumps({"is_valid": True, "answer_wrong_format": "Example summary."})}
    )
    generalist_agent = GeneralistAgent("llm", "mock_model")

    result = await generalist_agent.invoke("example query")

    expected_response = {"content": "Error in answer format.", "ignore_validation": "false"}
    assert json.loads(result) == expected_response
Loading

0 comments on commit 72d8eaa

Please sign in to comment.