forked from ScottLogic/InferLLM
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request WaitThatShouldntWork#97 from WaitThatShouldntWork/…
…feature/accurate-retrieval-benchmark Feature/accurate retrieval benchmark
- Loading branch information
Showing
23 changed files
with
611 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,56 @@ | ||
from src.prompts import PromptEngine | ||
from src.agents import Agent, agent | ||
import logging | ||
import os | ||
import json | ||
from src.utils.config import Config | ||
|
||
|
||
config = Config() | ||
|
||
engine = PromptEngine() | ||
intent_format = engine.load_prompt("intent-format") | ||
logger = logging.getLogger(__name__) | ||
FILES_DIRECTORY = f"/app/{config.files_directory}" | ||
|
||
# Constants for response status | ||
IGNORE_VALIDATION = "true" | ||
STATUS_SUCCESS = "success" | ||
STATUS_ERROR = "error" | ||
|
||
@agent( | ||
name="IntentAgent", | ||
description="This agent is responsible for determining the intent of the user's utterance", | ||
tools=[], | ||
) | ||
class IntentAgent(Agent): | ||
|
||
async def read_file_core(self, file_path: str) -> str: | ||
full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path)) | ||
try: | ||
with open(full_path, 'r') as file: | ||
content = file.read() | ||
return content | ||
except FileNotFoundError: | ||
error_message = f"File {file_path} not found." | ||
logger.error(error_message) | ||
return "" | ||
except Exception as e: | ||
logger.error(f"Error reading file {full_path}: {e}") | ||
return "" | ||
|
||
async def invoke(self, utterance: str) -> str: | ||
user_prompt = engine.load_prompt("intent", question=utterance) | ||
chat_history = await self.read_file_core("conversation-history.txt") | ||
|
||
user_prompt = engine.load_prompt("intent", question=utterance, chat_history=chat_history) | ||
|
||
return await self.llm.chat(self.model, intent_format, user_prompt=user_prompt, return_json=True) | ||
|
||
|
||
# Utility function for error responses | ||
def create_response(content: str, status: str = STATUS_SUCCESS) -> str: | ||
return json.dumps({ | ||
"content": content, | ||
"ignore_validation": IGNORE_VALIDATION, | ||
"status": status | ||
}, indent=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,98 @@ | ||
from .tool import tool | ||
from .agent_types import Parameter | ||
from .agent import Agent, agent | ||
import logging | ||
from src.utils import Config | ||
from .validator_agent import ValidatorAgent | ||
import json | ||
from src.utils.web_utils import perform_math_operation_util | ||
|
||
logger = logging.getLogger(__name__) | ||
config = Config() | ||
|
||
@tool( | ||
name="sum list of values", | ||
description="sums a list of provided values", | ||
parameters={ | ||
"list_of_values": Parameter( | ||
type="list[number]", | ||
description="Python list of comma separated values (e.g. [1, 5, 3])", | ||
) | ||
}, | ||
) | ||
async def sum_list_of_values(list_of_values) -> str: | ||
if not isinstance(list_of_values, list): | ||
raise Exception("Method not passed a valid Python list") | ||
return f"The sum of all the values passed {list_of_values} is {str(sum(list_of_values))}" | ||
async def perform_math_operation_core(math_query, llm, model) -> str: | ||
try: | ||
# Call the utility function to perform the math operation | ||
math_operation_result = await perform_math_operation_util(math_query, llm, model) | ||
|
||
result_json = json.loads(math_operation_result) | ||
|
||
if result_json.get("status") == "success": | ||
# Extract the relevant response (math result) from the utility function's output | ||
response = result_json.get("response", {}) | ||
response_json = json.loads(response) | ||
result = response_json.get("result", "") | ||
if result: | ||
logger.info(f"Math operation successful: {result}") | ||
is_valid = await is_valid_answer(result, math_query) | ||
logger.info(f"Is the answer valid: {is_valid}") | ||
if is_valid: | ||
response = { | ||
"content": result, | ||
"ignore_validation": "true" | ||
} | ||
return json.dumps(response, indent=4) | ||
else: | ||
response = { | ||
"content": "No valid result found for the math query.", | ||
"ignore_validation": "true" | ||
} | ||
return json.dumps(response, indent=4) | ||
else: | ||
response = { | ||
"content": None, | ||
"status": "error" | ||
} | ||
return json.dumps(response, indent=4) | ||
except Exception as e: | ||
logger.error(f"Error in perform_math_operation_core: {e}") | ||
response = { | ||
"content": None, | ||
"status": "error" | ||
} | ||
return json.dumps(response, indent=4) | ||
|
||
# Ensure a return statement in all code paths | ||
response = { | ||
"content": None, | ||
"status": "error" | ||
} | ||
return json.dumps(response, indent=4) | ||
|
||
def get_validator_agent() -> Agent: | ||
return ValidatorAgent(config.validator_agent_llm, config.validator_agent_model) | ||
|
||
async def is_valid_answer(answer, task) -> bool: | ||
is_valid = (await get_validator_agent().invoke(f"Task: {task} Answer: {answer}")).lower() == "true" | ||
return is_valid | ||
|
||
# Math Operation Tool | ||
@tool( | ||
name="compare two values", | ||
description="Compare two passed values and return information on which one is greater", | ||
name="perform_math_operation", | ||
description=( | ||
"Use this tool to perform complex mathematical operations or calculations. " | ||
"It handles arithmetic operations and algebra, and also supports conversions to specific units like millions," | ||
"rounding when necessary. Returns both the result and an explanation of the steps involved." | ||
), | ||
parameters={ | ||
"thing_one": Parameter( | ||
type="string", | ||
description="first thing for comparison", | ||
), | ||
"value_one": Parameter( | ||
type="number", | ||
description="value of first thing", | ||
), | ||
"thing_two": Parameter( | ||
"math_query": Parameter( | ||
type="string", | ||
description="second thing for comparison", | ||
), | ||
"value_two": Parameter( | ||
type="number", | ||
description="value of first thing", | ||
description="The mathematical query or equation to solve." | ||
), | ||
}, | ||
) | ||
async def compare_two_values(value_one, thing_one, value_two, thing_two) -> str: | ||
if value_one > value_two: | ||
return f"You have spent more on {thing_one} ({value_one}) than {thing_two} ({value_two}) in the last month" | ||
else: | ||
return f"You have spent more on {thing_two} ({value_two}) than {thing_one} ({value_one}) in the last month" | ||
|
||
async def perform_math_operation(math_query, llm, model) -> str: | ||
return await perform_math_operation_core(math_query, llm, model) | ||
|
||
# MathAgent definition | ||
@agent( | ||
name="MathsAgent", | ||
description="This agent is responsible for solving number comparison and calculation tasks", | ||
tools=[sum_list_of_values, compare_two_values], | ||
description=( | ||
"This agent processes mathematical queries, performs calculations, and applies necessary formatting such as" | ||
"rounding or converting results into specific units (e.g., millions). " | ||
"It provides clear explanations of the steps involved to ensure accuracy." | ||
), | ||
tools=[perform_math_operation], | ||
) | ||
class MathsAgent(Agent): | ||
pass |
Oops, something went wrong.