forked from ScottLogic/InferLLM
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FS-124 Add generalist agent and update supervisor to call this as a f…
…all back (#55)
- Loading branch information
Showing
9 changed files
with
157 additions
and
235 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import logging | ||
from src.prompts import PromptEngine | ||
from src.agents import ChatAgent, chat_agent | ||
from src.utils import Config | ||
import json | ||
|
||
logger = logging.getLogger(__name__) | ||
config = Config() | ||
|
||
engine = PromptEngine() | ||
|
||
|
||
@chat_agent( | ||
name="GeneralistAgent", | ||
description="This agent attempts to answer a general question using only the llm", | ||
tools=[], | ||
) | ||
class GeneralistAgent(ChatAgent): | ||
async def invoke(self, utterance) -> str: | ||
try: | ||
answer_to_user = await answer_user_question(utterance, self.llm, self.model) | ||
answer_result = json.loads(answer_to_user) | ||
final_answer = json.loads(answer_result["response"]).get("answer", "") | ||
if not final_answer: | ||
response = {"content": "Error in answer format.", "ignore_validation": "false"} | ||
return json.dumps(response, indent=4) | ||
logger.info(f"Answer found successfully {final_answer}") | ||
response = {"content": final_answer, "ignore_validation": "false"} | ||
return json.dumps(response, indent=4) | ||
|
||
except Exception as e: | ||
logger.error(f"Error in web_general_search_core: {e}") | ||
return "An error occurred while processing the search query." | ||
|
||
|
||
async def answer_user_question(search_query, llm, model) -> str: | ||
try: | ||
summariser_prompt = engine.load_prompt("generalist-answer", question=search_query) | ||
response = await llm.chat(model, summariser_prompt, "") | ||
return json.dumps( | ||
{ | ||
"status": "success", | ||
"response": response, | ||
"error": None, | ||
} | ||
) | ||
except Exception as e: | ||
logger.error(f"Error during create search term: {e}") | ||
return json.dumps( | ||
{ | ||
"status": "error", | ||
"response": None, | ||
"error": str(e), | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
You are an expert in providing accurate and complete answers to user queries regarding ESG. Your task is to **Generate a detailed answer** to the user's question. | ||
|
||
User's question is: | ||
{{ question }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import pytest | ||
from unittest.mock import patch, AsyncMock | ||
import json | ||
from src.agents.generalist_agent import GeneralistAgent | ||
|
||
|
||
@pytest.mark.asyncio | ||
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock) | ||
async def test_generalist_agent( | ||
mock_answer_user_question, | ||
): | ||
mock_answer_user_question.return_value = json.dumps( | ||
{"status": "success", "response": json.dumps({"is_valid": True, "answer": "Example summary."})} | ||
) | ||
generalist_agent = GeneralistAgent("llm", "mock_model") | ||
|
||
result = await generalist_agent.invoke("example query") | ||
expected_response = {"content": "Example summary.", "ignore_validation": "false"} | ||
assert json.loads(result) == expected_response | ||
|
||
|
||
@pytest.mark.asyncio | ||
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock) | ||
async def test_generalist_agent_reponse_format_error( | ||
mock_answer_user_question, | ||
): | ||
mock_answer_user_question.return_value = json.dumps( | ||
{"status": "success", "response": json.dumps({"is_valid": True, "answer_wrong_format": "Example summary."})} | ||
) | ||
generalist_agent = GeneralistAgent("llm", "mock_model") | ||
|
||
result = await generalist_agent.invoke("example query") | ||
|
||
expected_response = {"content": "Error in answer format.", "ignore_validation": "false"} | ||
assert json.loads(result) == expected_response |
Oops, something went wrong.