forked from ScottLogic/InferLLM
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FS-124 Add generalist agent and update supervisor to call this as a …
…fall back
- Loading branch information
Showing
8 changed files
with
139 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import logging | ||
from src.prompts import PromptEngine | ||
from src.agents import ChatAgent, chat_agent | ||
from src.utils import Config | ||
from src.utils.web_utils import ( | ||
answer_user_question, | ||
) | ||
import json | ||
|
||
logger = logging.getLogger(__name__) | ||
config = Config() | ||
|
||
engine = PromptEngine() | ||
|
||
|
||
@chat_agent(
    name="GeneralistAgent",
    description="This agent attempts to answer a general question using only the llm",
    tools=[],
)
class GeneralistAgent(ChatAgent):
    """Agent that answers a question by delegating to ``answer_user_question``.

    Every handled outcome is returned as a JSON string holding a ``content``
    field and an ``ignore_validation`` field (always the string ``"false"``).
    """

    @staticmethod
    def _build_response(content) -> str:
        """Serialise ``content`` into the JSON envelope returned by ``invoke``."""
        return json.dumps({"content": content, "ignore_validation": "false"}, indent=4)

    async def invoke(self, utterance) -> str:
        """Answer ``utterance`` and return the result as a JSON string.

        Args:
            utterance: The question forwarded to ``answer_user_question``
                together with ``self.llm`` and ``self.model``.

        Returns:
            A JSON string whose ``content`` is the answer, or
            ``"Error in finding the answer."`` when the helper reports an
            ``"error"`` status, or ``"Error in answer format."`` when the
            decoded response has no non-empty ``"answer"`` value. If anything
            raises, a plain-text error sentence is returned instead.
        """
        try:
            answer_to_user = await answer_user_question(utterance, self.llm, self.model)
            answer_result = json.loads(answer_to_user)
            if answer_result["status"] == "error":
                return self._build_response("Error in finding the answer.")

            # The "response" field is itself JSON-encoded, hence the second decode.
            final_answer = json.loads(answer_result["response"]).get("answer", "")
            if not final_answer:
                return self._build_response("Error in answer format.")

            logger.info(f"Answer found successfully {final_answer}")
            return self._build_response(final_answer)
        except Exception as e:
            # logger.exception keeps the traceback; the message now names this
            # method instead of the unrelated "web_general_search_core".
            logger.exception(f"Error in GeneralistAgent.invoke: {e}")
            # NOTE(review): unlike every other path this is plain text, not the
            # JSON envelope - confirm callers cope before changing it.
            return "An error occurred while processing the search query."
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,4 @@ | ||
You are an expert in providing accurate and complete answers to user queries. Your task is twofold: | ||
|
||
1. **Generate a detailed answer** to the user's question based on the provided content or context. | ||
2. **Validate** if the generated answer directly addresses the user's question and is factually accurate. | ||
You are an expert in providing accurate and complete answers to user queries. Your task is to **Generate a detailed answer** to the user's question. | ||
|
||
User's question is: | ||
{{ question }} | ||
|
||
Once you generate an answer: | ||
- **Check** if the answer completely and accurately addresses the user's question. | ||
- **Determine** if the answer is valid, based on the content provided. | ||
- **Provide** a reason for the validity or invalidity of the answer. | ||
- **Indicate** whether a more general web search is required. | ||
|
||
Reply only in JSON format with the following structure: | ||
|
||
```json | ||
{ | ||
"answer": "The answer to the user's question, based on the content provided", | ||
"should_perform_web_search": true or false, | ||
"perform_web_search_reason": "A sentence explaining whether the answer is valid or not, and why" | ||
} | ||
|
||
### **Example of Usage:** | ||
|
||
#### **User’s Question:** | ||
- **"What is Tesla's revenue every year since its creation?"** | ||
|
||
#### **Content Provided:** | ||
- A table or a paragraph with data on Tesla's revenue for various years. | ||
|
||
#### **LLM’s Response:** | ||
|
||
```json | ||
{ | ||
"answer": "Tesla's revenue since its creation is: 2008: $15 million, 2009: $30 million, ..., 2023: $81 billion.", | ||
"should_perform_web_search": false, | ||
"perform_web_search_reason": "The answer includes Tesla's revenue for every year since its creation, based on the data provided. No further search is required." | ||
} | ||
|
||
{ | ||
"answer": "Tesla's revenue for 2010 to 2023 is available, but data for the earlier years is missing.", | ||
"should_perform_web_search": true, | ||
"perform_web_search_reason": "The answer is incomplete because data for Tesla's early years is missing. A more general web search is required." | ||
} | ||
|
||
Important: If the question relates to real-time data, the LLM should set should_perform_web_search to true. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pytest | ||
from unittest.mock import patch, AsyncMock | ||
import json | ||
from src.agents.generalist_agent import GeneralistAgent | ||
|
||
|
||
@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent(mock_answer_user_question):
    """A successful helper reply is surfaced as the ``content`` of the result."""
    # The helper returns JSON whose "response" field is itself a JSON string.
    inner_payload = json.dumps({"is_valid": True, "answer": "Example summary."})
    mock_answer_user_question.return_value = json.dumps({"status": "success", "response": inner_payload})

    agent = GeneralistAgent("llm", "mock_model")
    raw_reply = await agent.invoke("example query")

    assert json.loads(raw_reply) == {"content": "Example summary.", "ignore_validation": "false"}
|
||
|
||
@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent_response_error(
    mock_answer_user_question,
):
    """An ``"error"`` status from the helper yields the generic error content.

    The helper is patched so the test is deterministic and does not run the
    real ``answer_user_question`` against the placeholder ``"llm"`` argument.
    """
    mock_answer_user_question.return_value = json.dumps({"status": "error", "response": "None"})
    generalist_agent = GeneralistAgent("llm", "mock_model")

    result = await generalist_agent.invoke("example query")

    expected_response = {"content": "Error in finding the answer.", "ignore_validation": "false"}
    assert json.loads(result) == expected_response
|
||
|
||
@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent_reponse_format_error(mock_answer_user_question):
    """A successful reply lacking an ``"answer"`` key yields the format error."""
    # The inner payload deliberately uses the wrong key name.
    malformed_payload = json.dumps({"is_valid": True, "answer_wrong_format": "Example summary."})
    mock_answer_user_question.return_value = json.dumps({"status": "success", "response": malformed_payload})

    agent = GeneralistAgent("llm", "mock_model")
    raw_reply = await agent.invoke("example query")

    assert json.loads(raw_reply) == {"content": "Error in answer format.", "ignore_validation": "false"}
Oops, something went wrong.