From ea221b1de20b7f9e8c465d02d7447f2360ca1020 Mon Sep 17 00:00:00 2001 From: Emma Pearce Date: Wed, 18 Dec 2024 16:57:47 +0000 Subject: [PATCH] address Ivans comments --- backend/src/agents/generalist_agent.py | 28 ++++++++++++---- backend/src/agents/web_agent.py | 30 +++++------------ ...-user-question.j2 => generalist-answer.j2} | 2 +- backend/src/supervisors/supervisor.py | 33 ++++++++++--------- backend/src/utils/web_utils.py | 22 ------------- backend/tests/agents/generalist_agent_test.py | 10 ------ 6 files changed, 49 insertions(+), 76 deletions(-) rename backend/src/prompts/templates/{answer-user-question.j2 => generalist-answer.j2} (50%) diff --git a/backend/src/agents/generalist_agent.py b/backend/src/agents/generalist_agent.py index 14f19185..afa72e5f 100644 --- a/backend/src/agents/generalist_agent.py +++ b/backend/src/agents/generalist_agent.py @@ -2,9 +2,6 @@ from src.prompts import PromptEngine from src.agents import ChatAgent, chat_agent from src.utils import Config -from src.utils.web_utils import ( - answer_user_question, -) import json logger = logging.getLogger(__name__) @@ -23,9 +20,6 @@ async def invoke(self, utterance) -> str: try: answer_to_user = await answer_user_question(utterance, self.llm, self.model) answer_result = json.loads(answer_to_user) - if answer_result["status"] == "error": - response = {"content": "Error in finding the answer.", "ignore_validation": "false"} - return json.dumps(response, indent=4) final_answer = json.loads(answer_result["response"]).get("answer", "") if not final_answer: response = {"content": "Error in answer format.", "ignore_validation": "false"} @@ -37,3 +31,25 @@ async def invoke(self, utterance) -> str: except Exception as e: logger.error(f"Error in web_general_search_core: {e}") return "An error occurred while processing the search query." + + +async def answer_user_question(search_query, llm, model) -> str: + try: + summariser_prompt = engine.load_prompt("generalist-answer", question=search_query) + response = await llm.chat(model, summariser_prompt, "") + return json.dumps( + { + "status": "success", + "response": response, + "error": None, + } + ) + except Exception as e: + logger.error(f"Error during create search term: {e}") + return json.dumps( + { + "status": "error", + "response": None, + "error": str(e), + } + ) diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py index 5c389e97..2d15d979 100644 --- a/backend/src/agents/web_agent.py +++ b/backend/src/agents/web_agent.py @@ -17,7 +17,7 @@ import io from pypdf import PdfReader import json -from typing import Dict, Any +from typing import Any logger = logging.getLogger(__name__) config = Config() @@ -29,32 +29,29 @@ async def web_general_search_core(search_query, llm, model) -> str: try: search_term_json = await create_search_term(search_query, llm, model) search_term_result = json.loads(search_term_json) - - # Step 1: Check if there was an error in generating the search term - if search_term_result.get("status") == "error": - response = {"content": search_term_result.get("error"), "ignore_validation": "false"} - return json.dumps(response, indent=4) search_term = json.loads(search_term_result["response"]).get("search_term", "") - # Step 2: Perform the search using the generated search term - search_result = await perform_search(search_term, num_results=15) + # Step 1: Perform the search using the generated search term + search_result_json = await search_urls(search_query, num_results=15) + search_result = json.loads(search_result_json) + if search_result.get("status") == "error": return "No relevant information found on the internet for the given query." urls = search_result.get("urls", []) logger.info(f"URLs found: {urls}") - # Step 3: Scrape content from the URLs found + # Step 2: Scrape content from the URLs found for url in urls: content = await perform_scrape(url) if not content: continue # Skip to the next URL if no content is found - # logger.info(f"Content scraped successfully: {content}") - # Step 4: Summarize the scraped content based on the search term + logger.info(f"Content scraped successfully: {content}") + # Step 3: Summarize the scraped content based on the search term summary = await perform_summarization(search_term, content, llm, model) if not summary: continue # Skip if no summary was generated - # Step 5: Validate the summarization + # Step 4: Validate the summarization is_valid = await is_valid_answer(summary, search_term) if not is_valid: continue # Skip if the summarization is not valid @@ -193,15 +190,6 @@ async def is_valid_answer(answer, task) -> bool: return is_valid -async def perform_search(search_query: str, num_results: int) -> Dict[str, Any]: - try: - search_result_json = await search_urls(search_query, num_results=num_results) - return json.loads(search_result_json) - except Exception as e: - logger.error(f"Error during web search: {e}") - return {"status": "error", "urls": []} - - async def perform_scrape(url: str) -> str: try: if not str(url).startswith("https"): diff --git a/backend/src/prompts/templates/answer-user-question.j2 b/backend/src/prompts/templates/generalist-answer.j2 similarity index 50% rename from backend/src/prompts/templates/answer-user-question.j2 rename to backend/src/prompts/templates/generalist-answer.j2 index 56a291f2..81bb609a 100644 --- a/backend/src/prompts/templates/answer-user-question.j2 +++ b/backend/src/prompts/templates/generalist-answer.j2 @@ -1,4 +1,4 @@ -You are an expert in providing accurate and complete answers to user queries. Your task is to **Generate a detailed answer** to the user's question. +You are an expert in providing accurate and complete answers to user queries regarding ESG. Your task is to **Generate a detailed answer** to the user's question. User's question is: {{ question }} diff --git a/backend/src/supervisors/supervisor.py b/backend/src/supervisors/supervisor.py index c9fd6e7a..201b735d 100644 --- a/backend/src/supervisors/supervisor.py +++ b/backend/src/supervisors/supervisor.py @@ -34,22 +34,23 @@ async def solve_all(intent_json) -> None: async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]: - if attempt == 3: - agent = get_generalist_agent() - else: - agent = await get_agent_for_task(task, scratchpad) - if agent is None: - raise Exception(no_agent_response) - logger.info(f"Agent selected: {agent.name}") - logger.info(f"Task is {task}") - answer = await agent.invoke(task) - parsed_json = json.loads(answer) - status = parsed_json.get("status", "success") - ignore_validation = parsed_json.get("ignore_validation", "") - answer_content = parsed_json.get("content", "") - if (ignore_validation == "true") or await is_valid_answer(answer_content, task): - return (agent.name, answer_content, status) - return await solve_task(task, scratchpad, attempt + 1) + for attempt in [1, 2, 3, 4]: + if attempt == 4: + agent = get_generalist_agent() + else: + agent = await get_agent_for_task(task, scratchpad) + if agent is None: + raise Exception(no_agent_response) + logger.info(f"Agent selected: {agent.name}") + logger.info(f"Task is {task}") + answer = await agent.invoke(task) + parsed_json = json.loads(answer) + status = parsed_json.get("status", "success") + ignore_validation = parsed_json.get("ignore_validation", "") + answer_content = parsed_json.get("content", "") + if (ignore_validation == "true") or await is_valid_answer(answer_content, task): + return (agent.name, answer_content, status) + raise Exception(unsolvable_response) async def is_valid_answer(answer, task) -> bool: diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py index e2997acc..8d90dabf 100644 --- a/backend/src/utils/web_utils.py +++ b/backend/src/utils/web_utils.py @@ -83,28 +83,6 @@ async def create_search_term(search_query, llm, model) -> str: ) -async def answer_user_question(search_query, llm, model) -> str: - try: - summariser_prompt = engine.load_prompt("answer-user-question", question=search_query) - response = await llm.chat(model, summariser_prompt, "") - return json.dumps( - { - "status": "success", - "response": response, - "error": None, - } - ) - except Exception as e: - logger.error(f"Error during create search term: {e}") - return json.dumps( - { - "status": "error", - "response": None, - "error": str(e), - } - ) - - async def summarise_content(search_query, contents, llm, model) -> str: try: summariser_prompt = engine.load_prompt("summariser", question=search_query, content=contents) diff --git a/backend/tests/agents/generalist_agent_test.py b/backend/tests/agents/generalist_agent_test.py index adfe45f0..2d67890f 100644 --- a/backend/tests/agents/generalist_agent_test.py +++ b/backend/tests/agents/generalist_agent_test.py @@ -19,16 +19,6 @@ async def test_generalist_agent( assert json.loads(result) == expected_response -@pytest.mark.asyncio -async def test_generalist_agent_response_error(): - generalist_agent = GeneralistAgent("llm", "mock_model") - - result = await generalist_agent.invoke("example query") - - expected_response = {"content": "Error in finding the answer.", "ignore_validation": "false"} - assert json.loads(result) == expected_response - - @pytest.mark.asyncio @patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock) async def test_generalist_agent_reponse_format_error(