From 5ad8df30c49a3b89b1c949fab8cedabf36279ec6 Mon Sep 17 00:00:00 2001 From: Emma Pearce Date: Tue, 17 Dec 2024 10:32:12 +0000 Subject: [PATCH] FS-124 Add generalist agent and update supervisor to call this as a fallback --- backend/src/agents/__init__.py | 5 ++ backend/src/agents/generalist_agent.py | 39 ++++++++ backend/src/agents/web_agent.py | 89 +++++++------------ .../prompts/templates/answer-user-question.j2 | 45 +--------- backend/src/supervisors/supervisor.py | 10 +-- backend/src/utils/web_utils.py | 7 +- backend/tests/agents/generalist_agent_test.py | 45 ++++++++++ backend/tests/agents/web_agent_test.py | 89 ++----------------- 8 files changed, 139 insertions(+), 190 deletions(-) create mode 100644 backend/src/agents/generalist_agent.py create mode 100644 backend/tests/agents/generalist_agent_test.py diff --git a/backend/src/agents/__init__.py b/backend/src/agents/__init__.py index 3a973110..8a8963e7 100644 --- a/backend/src/agents/__init__.py +++ b/backend/src/agents/__init__.py @@ -11,6 +11,7 @@ from src.agents.chart_generator_agent import ChartGeneratorAgent from src.agents.report_agent import ReportAgent from src.agents.materiality_agent import MaterialityAgent +from src.agents.generalist_agent import GeneralistAgent config = Config() @@ -36,6 +37,10 @@ def get_materiality_agent() -> MaterialityAgent: return MaterialityAgent(config.materiality_agent_llm, config.materiality_agent_model) +def get_generalist_agent() -> GeneralistAgent: + return GeneralistAgent(config.intent_agent_llm, config.intent_agent_model) + + def agent_details(agent: ChatAgent) -> dict: return {"name": agent.name, "description": agent.description} diff --git a/backend/src/agents/generalist_agent.py b/backend/src/agents/generalist_agent.py new file mode 100644 index 00000000..14f19185 --- /dev/null +++ b/backend/src/agents/generalist_agent.py @@ -0,0 +1,39 @@ +import logging +from src.prompts import PromptEngine +from src.agents import ChatAgent, chat_agent +from src.utils import 
Config
+from src.utils.web_utils import answer_user_question
+import json
+
+logger = logging.getLogger(__name__)
+config = Config()
+
+engine = PromptEngine()
+
+
+@chat_agent(
+    name="GeneralistAgent",
+    description="This agent attempts to answer a general question using only the llm",
+    tools=[],
+)
+class GeneralistAgent(ChatAgent):
+    """Tool-less agent that answers the user's question directly from the LLM."""
+
+    async def invoke(self, utterance) -> str:
+        """Answer `utterance` via the LLM; returns a JSON string with "content" and "ignore_validation"."""
+        try:
+            answer_result = json.loads(await answer_user_question(utterance, self.llm, self.model))
+            if answer_result["status"] == "error":
+                return json.dumps({"content": "Error in finding the answer.", "ignore_validation": "false"}, indent=4)
+            try:  # the LLM reply may be plain text rather than JSON; fall back to the raw text
+                parsed = json.loads(answer_result["response"])
+                final_answer = parsed.get("answer", "") if isinstance(parsed, dict) else answer_result["response"]
+            except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
+                final_answer = answer_result["response"]
+            if not final_answer:
+                return json.dumps({"content": "Error in answer format.", "ignore_validation": "false"}, indent=4)
+            logger.info(f"Answer found successfully {final_answer}")
+            return json.dumps({"content": final_answer, "ignore_validation": "false"}, indent=4)
+        except Exception as e:
+            logger.error(f"Error in GeneralistAgent.invoke: {e}")
+            return "An error occurred while processing the search query."  # plain string, not JSON
diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py index d4b6005a..5c389e97 100644 --- a/backend/src/agents/web_agent.py +++ b/backend/src/agents/web_agent.py @@ -11,7 +11,6 @@ summarise_pdf_content, find_info, create_search_term, - answer_user_question, ) from .validator_agent import ValidatorAgent import aiohttp @@ -28,59 +27,40 @@ async def web_general_search_core(search_query, llm, model) -> str: try: - # Step 1: Generate the search term from the user's query - answer_to_user = await answer_user_question(search_query, llm, model) - answer_result = json.loads(answer_to_user) - if answer_result["status"] == "error": - response = {"content": "Error in finding the answer.", "ignore_validation": "false"} - return json.dumps(response, indent=4) - logger.info(f"Answer found successfully {answer_result}") - should_perform_web_search = json.loads(answer_result["response"]).get("should_perform_web_search", "") - if not should_perform_web_search: - final_answer = json.loads(answer_result["response"]).get("answer", "") - if not final_answer: - return "No answer found." 
- logger.info(f"Answer found successfully {final_answer}") - response = {"content": final_answer, "ignore_validation": "false"} + search_term_json = await create_search_term(search_query, llm, model) + search_term_result = json.loads(search_term_json) + + # Step 1: Check if there was an error in generating the search term + if search_term_result.get("status") == "error": + response = {"content": search_term_result.get("error"), "ignore_validation": "false"} return json.dumps(response, indent=4) - else: - search_term_json = await create_search_term(search_query, llm, model) - search_term_result = json.loads(search_term_json) - - # Check if there was an error in generating the search term - if search_term_result.get("status") == "error": - response = {"content": search_term_result.get("error"), "ignore_validation": "false"} - return json.dumps(response, indent=4) - search_term = json.loads(search_term_result["response"]).get("search_term", "") - - # Step 2: Perform the search using the generated search term - search_result = await perform_search(search_term, num_results=15) - if search_result.get("status") == "error": - return "No relevant information found on the internet for the given query." 
- urls = search_result.get("urls", []) - logger.info(f"URLs found: {urls}") - - # Step 3: Scrape content from the URLs found - for url in urls: - content = await perform_scrape(url) - if not content: - continue # Skip to the next URL if no content is found - # logger.info(f"Content scraped successfully: {content}") - # Step 4: Summarize the scraped content based on the search term - summary = await perform_summarization(search_term, content, llm, model) - if not summary: - continue # Skip if no summary was generated - - # Step 5: Validate the summarization - is_valid = await is_valid_answer(summary, search_term) - if not is_valid: - continue # Skip if the summarization is not valid - response = { - "content": { "content": summary, "url": url }, - "ignore_validation": "false" - } - return json.dumps(response, indent=4) + search_term = json.loads(search_term_result["response"]).get("search_term", "") + + # Step 2: Perform the search using the generated search term + search_result = await perform_search(search_term, num_results=15) + if search_result.get("status") == "error": return "No relevant information found on the internet for the given query." 
+ urls = search_result.get("urls", []) + logger.info(f"URLs found: {urls}") + + # Step 3: Scrape content from the URLs found + for url in urls: + content = await perform_scrape(url) + if not content: + continue # Skip to the next URL if no content is found + # logger.info(f"Content scraped successfully: {content}") + # Step 4: Summarize the scraped content based on the search term + summary = await perform_summarization(search_term, content, llm, model) + if not summary: + continue # Skip if no summary was generated + + # Step 5: Validate the summarization + is_valid = await is_valid_answer(summary, search_term) + if not is_valid: + continue # Skip if the summarization is not valid + response = {"content": {"content": summary, "url": url}, "ignore_validation": "false"} + return json.dumps(response, indent=4) + return "No relevant information found on the internet for the given query." except Exception as e: logger.error(f"Error in web_general_search_core: {e}") return "An error occurred while processing the search query." @@ -148,10 +128,7 @@ async def web_scrape_core(url: str) -> str: return "No content found at the provided URL." logger.info(f"Content scraped successfully: {content}") content = content.replace("\n", " ").replace("\r", " ") - response = { - "content": { "content": content, "url": url }, - "ignore_validation": "true" - } + response = {"content": {"content": content, "url": url}, "ignore_validation": "true"} return json.dumps(response, indent=4) except Exception as e: return json.dumps({"status": "error", "error": str(e)}) diff --git a/backend/src/prompts/templates/answer-user-question.j2 b/backend/src/prompts/templates/answer-user-question.j2 index 1f0966a0..56a291f2 100644 --- a/backend/src/prompts/templates/answer-user-question.j2 +++ b/backend/src/prompts/templates/answer-user-question.j2 @@ -1,47 +1,4 @@ -You are an expert in providing accurate and complete answers to user queries. Your task is twofold: - -1. 
**Generate a detailed answer** to the user's question based on the provided content or context. -2. **Validate** if the generated answer directly addresses the user's question and is factually accurate. +You are an expert in providing accurate and complete answers to user queries. Your task is to **Generate a detailed answer** to the user's question. User's question is: {{ question }} - -Once you generate an answer: -- **Check** if the answer completely and accurately addresses the user's question. -- **Determine** if the answer is valid, based on the content provided. -- **Provide** a reason for the validity or invalidity of the answer. -- **Indicate** if a more general web search is required then indicate that. - -Reply only in JSON format with the following structure: - -```json -{ - "answer": "The answer to the user's question, based on the content provided", - "should_perform_web_search": true or false, - "perform_web_search_reason": "A sentence explaining whether the answer is valid or not, and why" -} - -### **Example of Usage:** - -#### **User’s Question:** -- **"What is Tesla's revenue every year since its creation?"** - -#### **Content Provided:** -- A table or a paragraph with data on Tesla's revenue for various years. - -#### **LLM’s Response:** - -```json -{ - "answer": "Tesla's revenue since its creation is: 2008: $15 million, 2009: $30 million, ..., 2023: $81 billion.", - "should_perform_web_search": false, - "perform_web_search_reason": "The answer includes Tesla's revenue for every year since its creation, based on the data provided. No further search is required." -} - -{ - "answer": "Tesla's revenue for 2010 to 2023 is available, but data for the earlier years is missing.", - "should_perform_web_search": true, - "perform_web_search_reason": "The answer is incomplete because data for Tesla's early years is missing. A more general web search is required." 
-} - -Important: If the question is related to real time data, the LLM should provide should_perform_web_search is true. diff --git a/backend/src/supervisors/supervisor.py b/backend/src/supervisors/supervisor.py index 9817bfeb..c9fd6e7a 100644 --- a/backend/src/supervisors/supervisor.py +++ b/backend/src/supervisors/supervisor.py @@ -2,7 +2,7 @@ import logging from src.utils import get_scratchpad, update_scratchpad from src.router import get_agent_for_task -from src.agents import get_validator_agent +from src.agents import get_validator_agent, get_generalist_agent import json logger = logging.getLogger(__name__) @@ -34,10 +34,10 @@ async def solve_all(intent_json) -> None: async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]: - if attempt == 5: - raise Exception(unsolvable_response) - - agent = await get_agent_for_task(task, scratchpad) + if attempt == 3: + agent = get_generalist_agent() + else: + agent = await get_agent_for_task(task, scratchpad) if agent is None: raise Exception(no_agent_response) logger.info(f"Agent selected: {agent.name}") diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py index e83d43ee..e2997acc 100644 --- a/backend/src/utils/web_utils.py +++ b/backend/src/utils/web_utils.py @@ -60,6 +60,7 @@ async def scrape_content(url, limit=100000) -> str: } ) + async def create_search_term(search_query, llm, model) -> str: try: summariser_prompt = engine.load_prompt("create-search-term", question=search_query) @@ -81,10 +82,11 @@ async def create_search_term(search_query, llm, model) -> str: } ) + async def answer_user_question(search_query, llm, model) -> str: try: summariser_prompt = engine.load_prompt("answer-user-question", question=search_query) - response = await llm.chat(model, summariser_prompt, "", return_json=True) + response = await llm.chat(model, summariser_prompt, "") return json.dumps( { "status": "success", @@ -102,6 +104,7 @@ async def answer_user_question(search_query, llm, model) -> str: } 
) + async def summarise_content(search_query, contents, llm, model) -> str: try: summariser_prompt = engine.load_prompt("summariser", question=search_query, content=contents) @@ -123,6 +126,7 @@ async def summarise_content(search_query, contents, llm, model) -> str: } ) + async def summarise_pdf_content(contents, llm, model) -> str: try: summariser_prompt = engine.load_prompt("pdf-summariser", content=contents) @@ -144,6 +148,7 @@ async def summarise_pdf_content(contents, llm, model) -> str: } ) + async def perform_math_operation_util(math_query, llm, model) -> str: try: math_prompt = engine.load_prompt("math-solver", query=math_query) diff --git a/backend/tests/agents/generalist_agent_test.py b/backend/tests/agents/generalist_agent_test.py new file mode 100644 index 00000000..adfe45f0 --- /dev/null +++ b/backend/tests/agents/generalist_agent_test.py @@ -0,0 +1,45 @@ +import pytest +from unittest.mock import patch, AsyncMock +import json +from src.agents.generalist_agent import GeneralistAgent + + +@pytest.mark.asyncio +@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock) +async def test_generalist_agent( + mock_answer_user_question, +): + mock_answer_user_question.return_value = json.dumps( + {"status": "success", "response": json.dumps({"is_valid": True, "answer": "Example summary."})} + ) + generalist_agent = GeneralistAgent("llm", "mock_model") + + result = await generalist_agent.invoke("example query") + expected_response = {"content": "Example summary.", "ignore_validation": "false"} + assert json.loads(result) == expected_response + + +@pytest.mark.asyncio +async def test_generalist_agent_response_error(): + generalist_agent = GeneralistAgent("llm", "mock_model") + + result = await generalist_agent.invoke("example query") + + expected_response = {"content": "Error in finding the answer.", "ignore_validation": "false"} + assert json.loads(result) == expected_response + + +@pytest.mark.asyncio 
+@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock) +async def test_generalist_agent_reponse_format_error( + mock_answer_user_question, +): + mock_answer_user_question.return_value = json.dumps( + {"status": "success", "response": json.dumps({"is_valid": True, "answer_wrong_format": "Example summary."})} + ) + generalist_agent = GeneralistAgent("llm", "mock_model") + + result = await generalist_agent.invoke("example query") + + expected_response = {"content": "Error in answer format.", "ignore_validation": "false"} + assert json.loads(result) == expected_response diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py index a64a9867..fbd9e53e 100644 --- a/backend/tests/agents/web_agent_test.py +++ b/backend/tests/agents/web_agent_test.py @@ -1,104 +1,24 @@ import pytest from unittest.mock import patch, AsyncMock import json -from src.agents.web_agent import web_general_search_core, perform_scrape +from src.agents.web_agent import perform_scrape from src.utils.web_utils import search_urls -@pytest.mark.asyncio -@patch("src.agents.web_agent.answer_user_question", new_callable=AsyncMock) -@patch("src.agents.web_agent.create_search_term", new_callable=AsyncMock) -@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) -@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) -@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) -@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) -async def test_web_general_search_core( - mock_is_valid_answer, - mock_perform_summarization, - mock_perform_scrape, - mock_perform_search, - mock_create_search_term, - mock_answer_user_question -): - llm = AsyncMock() - model = "mock_model" - - # Mocking answer_user_question to return a valid answer - mock_answer_user_question.return_value = json.dumps({ - "status": "success", - "response": json.dumps({"is_valid": True, "answer": "Example summary."}) - }) 
- - result = await web_general_search_core("example query", llm, model) - expected_response = { - "content": "Example summary.", - "ignore_validation": "false" - } - assert json.loads(result) == expected_response - - -@pytest.mark.asyncio -@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) -@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) -@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) -@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) -async def test_web_general_search_core_no_results( - mock_is_valid_answer, - mock_perform_summarization, - mock_perform_scrape, - mock_perform_search, -): - llm = AsyncMock() - model = "mock_model" - mock_perform_search.return_value = {"status": "error", "urls": []} - result = await web_general_search_core("example query", llm, model) - - expected_response = { - "content": "Error in finding the answer.", - "ignore_validation": "false" - } - assert json.loads(result) == expected_response - -@pytest.mark.asyncio -@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) -@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) -@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) -@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) -async def test_web_general_search_core_invalid_summary( - mock_is_valid_answer, - mock_perform_summarization, - mock_perform_scrape, - mock_perform_search -): - llm = AsyncMock() - model = "mock_model" - mock_perform_search.return_value = {"status": "success", "urls": ["https://example.com"], "error": None} - mock_perform_scrape.return_value = "Example scraped content." 
- mock_perform_summarization.return_value = json.dumps({"summary": "Example invalid summary."}) - mock_is_valid_answer.return_value = False - result = await web_general_search_core("example query", llm, model) - expected_response = { - "content": "Error in finding the answer.", - "ignore_validation": "false" - } - assert json.loads(result) == expected_response @pytest.mark.asyncio @patch("src.utils.web_utils.search") async def test_https_urls(mock_search): - mock_search.return_value = [ - "https://example.com", - "http://nonsecure.com", - "https://another-secure-site.com" - ] + mock_search.return_value = ["https://example.com", "http://nonsecure.com", "https://another-secure-site.com"] result = await search_urls("query", num_results=5) expected_result = { "status": "success", "urls": ["https://example.com", "https://another-secure-site.com"], - "error": None + "error": None, } assert json.loads(result) == expected_result + @pytest.mark.asyncio @patch("src.agents.web_agent.scrape_content", new_callable=AsyncMock) async def test_perform_scrape_http_url(mock_scrape_content): @@ -107,6 +27,7 @@ async def test_perform_scrape_http_url(mock_scrape_content): result = await perform_scrape("http://nonsecure.com") assert result == "" + @pytest.mark.asyncio @patch("src.agents.web_agent.scrape_content", new_callable=AsyncMock) async def test_perform_scrape_https_url(mock_scrape_content):