Skip to content

Commit

Permalink
Merge pull request WaitThatShouldntWork#97 from WaitThatShouldntWork/feature/accurate-retrieval-benchmark
Browse files Browse the repository at this point in the history

Feature/accurate retrieval benchmark
  • Loading branch information
gaganahluwalia authored Oct 18, 2024
2 parents 5b3caa8 + 8aeccdc commit 70524a8
Show file tree
Hide file tree
Showing 23 changed files with 611 additions and 175 deletions.
3 changes: 3 additions & 0 deletions backend/src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from .answer_agent import AnswerAgent
from .chart_generator_agent import ChartGeneratorAgent
from .file_agent import FileAgent
from .maths_agent import MathsAgent


config = Config()

Expand All @@ -34,6 +36,7 @@ def get_available_agents() -> List[Agent]:
WebAgent(config.web_agent_llm, config.web_agent_model),
ChartGeneratorAgent(config.chart_generator_llm, config.chart_generator_model),
FileAgent(config.file_agent_llm, config.file_agent_model),
MathsAgent(config.maths_agent_llm, config.maths_agent_model),
]


Expand Down
1 change: 0 additions & 1 deletion backend/src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ async def __get_action(self, utterance: str) -> Action_and_args:

async def invoke(self, utterance: str) -> str:
    """Pick the best-matching action for *utterance*, run it, and return its result."""
    chosen_action, action_args = await self.__get_action(utterance)
    logger.info(f"USER - Action: {chosen_action} and args: {action_args} for utterance: {utterance}")
    outcome = await chosen_action(**action_args, llm=self.llm, model=self.model)
    await publish_log_info(LogPrefix.USER, f"Action gave result: {outcome}", __name__)
    return outcome
Expand Down
11 changes: 6 additions & 5 deletions backend/src/agents/chart_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from src.utils import scratchpad
from PIL import Image
import json
from src.websockets.user_confirmer import UserConfirmer
from src.websockets.confirmations_manager import confirmations_manager
# from src.websockets.user_confirmer import UserConfirmer
# from src.websockets.confirmations_manager import confirmations_manager

logger = logging.getLogger(__name__)

Expand All @@ -31,8 +31,9 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
sanitised_script = sanitise_script(generated_code)

try:
confirmer = UserConfirmer(confirmations_manager)
is_confirmed = await confirmer.confirm("Would you like to generate a graph?")
# confirmer = UserConfirmer(confirmations_manager)
is_confirmed = True
# await confirmer.confirm("Would you like to generate a graph?")
if not is_confirmed:
raise Exception("The user did not confirm to creating a graph.")
local_vars = {}
Expand All @@ -51,7 +52,7 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
raise
response = {
"content": image_data,
"ignore_validation": "false",
"ignore_validation": "true",
}
return json.dumps(response, indent=4)

Expand Down
28 changes: 19 additions & 9 deletions backend/src/agents/file_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,19 @@ async def read_file_core(file_path: str) -> str:
return create_response(f"Error reading file: {file_path}", STATUS_ERROR)


async def write_file_core(file_path: str, content: str) -> str:
async def write_or_update_file_core(file_path: str, content: str, update) -> str:
    """Write *content* to a file under FILES_DIRECTORY, overwriting or appending.

    Args:
        file_path: Path relative to FILES_DIRECTORY.
        content: Text to store.
        update: "yes" appends the content on a new line; any other value overwrites.

    Returns:
        JSON response string from create_response describing success or failure.
    """
    # NOTE(review): normpath does not stop "../" traversal out of FILES_DIRECTORY — confirm callers are trusted.
    full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
    appending = update == "yes"
    try:
        # Single open: select the mode once instead of duplicating the write logic per branch.
        mode = 'a' if appending else 'w'
        text = '\n' + content if appending else content
        with open(full_path, mode) as file:
            file.write(text)
        verb = "appended to" if appending else "written to"
        logger.info(f"Content {verb} file {full_path} successfully.")
        return create_response(f"Content {verb} file {file_path}.")
    except Exception as e:
        logger.error(f"Error writing to file {full_path}: {e}")
        return create_response(f"Error writing to file: {file_path}", STATUS_ERROR)
Expand All @@ -67,7 +73,7 @@ async def read_file(file_path: str, llm, model) -> str:

@tool(
name="write_file",
description="Write content to a text file.",
description="Write or update content to a text file.",
parameters={
"file_path": Parameter(
type="string",
Expand All @@ -77,16 +83,20 @@ async def read_file(file_path: str, llm, model) -> str:
type="string",
description="The content to write to the file."
),
"update": Parameter(
type="string",
description="if yes then just append the file"
),
},
)
async def write_file(file_path: str, content: str, llm, model) -> str:
return await write_file_core(file_path, content)
async def write_or_update_file(file_path: str, content: str, update, llm, model) -> str:
return await write_or_update_file_core(file_path, content, update)


@agent(
    name="FileAgent",
    description="This agent is responsible for reading from and writing to files.",
    tools=[read_file, write_or_update_file],
)
class FileAgent(Agent):
    # Agent exposing the read_file and write_or_update_file tools; all paths
    # are resolved relative to FILES_DIRECTORY by the tool implementations.
    pass
41 changes: 40 additions & 1 deletion backend/src/agents/intent_agent.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,56 @@
from src.prompts import PromptEngine
from src.agents import Agent, agent
import logging
import os
import json
from src.utils.config import Config


config = Config()

engine = PromptEngine()
intent_format = engine.load_prompt("intent-format")
logger = logging.getLogger(__name__)
FILES_DIRECTORY = f"/app/{config.files_directory}"

# Constants for response status
IGNORE_VALIDATION = "true"
STATUS_SUCCESS = "success"
STATUS_ERROR = "error"

@agent(
    name="IntentAgent",
    description="This agent is responsible for determining the intent of the user's utterance",
    tools=[],
)
class IntentAgent(Agent):
    """Determines the intent of a user utterance, using stored conversation history as context."""

    async def read_file_core(self, file_path: str) -> str:
        """Read a file under FILES_DIRECTORY and return its text, or "" on any error."""
        full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
        try:
            with open(full_path, 'r') as file:
                return file.read()
        except FileNotFoundError:
            logger.error(f"File {file_path} not found.")
            return ""
        except Exception as e:
            logger.error(f"Error reading file {full_path}: {e}")
            return ""

    async def invoke(self, utterance: str) -> str:
        """Build the intent prompt (question + chat history) and ask the LLM for a JSON intent."""
        chat_history = await self.read_file_core("conversation-history.txt")
        # Load the prompt exactly once, with history; the original assigned user_prompt
        # twice, the first call being dead code immediately overwritten.
        user_prompt = engine.load_prompt("intent", question=utterance, chat_history=chat_history)
        return await self.llm.chat(self.model, intent_format, user_prompt=user_prompt, return_json=True)


# Utility helper: serialize an agent response payload to JSON.
def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
    """Return the standard agent response JSON for *content* with the given *status*."""
    payload = {
        "content": content,
        "ignore_validation": IGNORE_VALIDATION,
        "status": status,
    }
    return json.dumps(payload, indent=4)
117 changes: 79 additions & 38 deletions backend/src/agents/maths_agent.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,98 @@
from .tool import tool
from .agent_types import Parameter
from .agent import Agent, agent
import logging
from src.utils import Config
from .validator_agent import ValidatorAgent
import json
from src.utils.web_utils import perform_math_operation_util

logger = logging.getLogger(__name__)
config = Config()

@tool(
    name="sum list of values",
    description="sums a list of provided values",
    parameters={
        "list_of_values": Parameter(
            type="list[number]",
            description="Python list of comma separated values (e.g. [1, 5, 3])",
        )
    },
)
async def sum_list_of_values(list_of_values) -> str:
    """Sum a Python list of numbers and report the total as a sentence."""
    if isinstance(list_of_values, list):
        total = sum(list_of_values)
        return f"The sum of all the values passed {list_of_values} is {str(total)}"
    raise Exception("Method not passed a valid Python list")
async def perform_math_operation_core(math_query, llm, model) -> str:
    """Run *math_query* through the math utility, validate the answer, and return a JSON response.

    Returns a JSON string with "content"/"ignore_validation" on success, or
    {"content": None, "status": "error"} on failure or an invalid answer.
    """
    # Single shared error payload; the original duplicated this dict three times.
    error_response = json.dumps({
        "content": None,
        "status": "error"
    }, indent=4)
    try:
        # Delegate the actual computation to the shared utility.
        math_operation_result = await perform_math_operation_util(math_query, llm, model)
        result_json = json.loads(math_operation_result)
        if result_json.get("status") != "success":
            return error_response
        # The utility nests its payload as a JSON string under "response".
        response_json = json.loads(result_json.get("response", {}))
        result = response_json.get("result", "")
        if not result:
            return json.dumps({
                "content": "No valid result found for the math query.",
                "ignore_validation": "true"
            }, indent=4)
        logger.info(f"Math operation successful: {result}")
        is_valid = await is_valid_answer(result, math_query)
        logger.info(f"Is the answer valid: {is_valid}")
        if is_valid:
            return json.dumps({
                "content": result,
                "ignore_validation": "true"
            }, indent=4)
        # Invalid answer: the original fell through the try block to a trailing
        # return of the error payload; made explicit here, same behavior.
        return error_response
    except Exception as e:
        logger.error(f"Error in perform_math_operation_core: {e}")
        return error_response

def get_validator_agent() -> Agent:
    """Construct a ValidatorAgent using the configured validator LLM and model."""
    return ValidatorAgent(config.validator_agent_llm, config.validator_agent_model)

async def is_valid_answer(answer, task) -> bool:
    """Ask the validator agent whether *answer* satisfies *task*; True iff it replies "true"."""
    verdict = await get_validator_agent().invoke(f"Task: {task} Answer: {answer}")
    return verdict.lower() == "true"

# Math Operation Tool
@tool(
    name="perform_math_operation",
    description=(
        "Use this tool to perform complex mathematical operations or calculations. "
        "It handles arithmetic operations and algebra, and also supports conversions to specific units like millions, "
        "rounding when necessary. Returns both the result and an explanation of the steps involved."
        # NOTE: a space was added before "rounding" — the original concatenation
        # produced "millions,rounding" in the description shown to the LLM.
    ),
    parameters={
        "math_query": Parameter(
            type="string",
            description="The mathematical query or equation to solve."
        ),
    },
)
async def perform_math_operation(math_query, llm, model) -> str:
    """Tool entry point: delegate to perform_math_operation_core."""
    return await perform_math_operation_core(math_query, llm, model)

# MathsAgent definition
@agent(
    name="MathsAgent",
    description=(
        "This agent processes mathematical queries, performs calculations, and applies necessary formatting such as "
        "rounding or converting results into specific units (e.g., millions). "
        # NOTE: a space was added after "such as" — the original concatenation
        # produced "such asrounding" in the description shown to the LLM.
        "It provides clear explanations of the steps involved to ensure accuracy."
    ),
    tools=[perform_math_operation],
)
class MathsAgent(Agent):
    """Agent that answers mathematical queries via the perform_math_operation tool."""
    pass
Loading

0 comments on commit 70524a8

Please sign in to comment.