diff --git a/.env.example b/.env.example index df8da3280..bcb17ddb2 100644 --- a/.env.example +++ b/.env.example @@ -13,6 +13,9 @@ NEO4J_URI=bolt://localhost:7687 NEO4J_HTTP_PORT=7474 NEO4J_BOLT_PORT=7687 +# files location +FILES_DIRECTORY=files + # backend LLM properties MISTRAL_KEY=my-api-key @@ -42,6 +45,7 @@ MATHS_AGENT_LLM="openai" WEB_AGENT_LLM="openai" CHART_GENERATOR_LLM="openai" ROUTER_LLM="openai" +FILE_AGENT_LLM="openai" # llm model ANSWER_AGENT_MODEL="gpt-4o mini" @@ -52,3 +56,4 @@ MATHS_AGENT_MODEL="gpt-4o mini" WEB_AGENT_MODEL="gpt-4o mini" CHART_GENERATOR_MODEL="gpt-4o mini" ROUTER_MODEL="gpt-4o mini" +FILE_AGENT_MODEL="gpt-4o mini" diff --git a/.gitignore b/.gitignore index a6d19fb54..7d7996c23 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ celerybeat.pid # Environments .env .venv +files env/ venv/ ENV/ diff --git a/backend/README.md b/backend/README.md index 9e09bcd62..cd21c0b95 100644 --- a/backend/README.md +++ b/backend/README.md @@ -37,7 +37,7 @@ Follow the instructions below to run the backend locally. Change directory to `/ ```bash pip install -r requirements.txt ``` - +> (VsCode) You may run into some issues with compiling python packages from requirements.txt. To resolve this ensure you have downloaded and installed the "Desktop development with C++" workload from your Visual Studio installer. 3. Run the app ```bash diff --git a/backend/requirements.txt b/backend/requirements.txt index 3dae443a3..59917b6dd 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,6 +1,6 @@ fastapi==0.110.0 uvicorn==0.29.0 -mistralai==0.1.8 +mistralai==1.1.0 pycodestyle==2.11.1 python-dotenv==1.0.1 neo4j==5.18.0 diff --git a/backend/src/agents/__init__.py b/backend/src/agents/__init__.py index cd58158db..c06af259f 100644 --- a/backend/src/agents/__init__.py +++ b/backend/src/agents/__init__.py @@ -8,6 +8,9 @@ from .validator_agent import ValidatorAgent from .answer_agent import AnswerAgent from .chart_generator_agent import ChartGeneratorAgent +from .file_agent import FileAgent +from .maths_agent import MathsAgent + config = Config() @@ -32,6 +35,8 @@ def get_available_agents() -> List[Agent]: return [DatastoreAgent(config.datastore_agent_llm, config.datastore_agent_model), WebAgent(config.web_agent_llm, config.web_agent_model), ChartGeneratorAgent(config.chart_generator_llm, config.chart_generator_model), + FileAgent(config.file_agent_llm, config.file_agent_model), + MathsAgent(config.maths_agent_llm, config.maths_agent_model), ] diff --git a/backend/src/agents/agent.py b/backend/src/agents/agent.py index ada5e9f47..32224b55d 100644 --- a/backend/src/agents/agent.py +++ b/backend/src/agents/agent.py @@ -56,7 +56,6 @@ async def __get_action(self, utterance: str) -> Action_and_args: async def invoke(self, utterance: str) -> str: (action, args) = await self.__get_action(utterance) - logger.info(f"USER - Action: {action} and args: {args} for utterance: {utterance}") result_of_action = await action(**args, llm=self.llm, model=self.model) await publish_log_info(LogPrefix.USER, f"Action gave result: {result_of_action}", __name__) return result_of_action diff --git a/backend/src/agents/chart_generator_agent.py b/backend/src/agents/chart_generator_agent.py index 156e51703..8479bc128 100644 --- a/backend/src/agents/chart_generator_agent.py +++ b/backend/src/agents/chart_generator_agent.py @@ -9,11 +9,14 @@ from src.utils import scratchpad from PIL import Image import json +# from src.websockets.user_confirmer import UserConfirmer +# from 
src.websockets.confirmations_manager import confirmations_manager logger = logging.getLogger(__name__) engine = PromptEngine() + async def generate_chart(question_intent, data_provided, question_params, llm: LLM, model) -> str: details_to_generate_chart_code = engine.load_prompt( "details-to-generate-chart-code", @@ -28,13 +31,18 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L sanitised_script = sanitise_script(generated_code) try: + # confirmer = UserConfirmer(confirmations_manager) + is_confirmed = True + # await confirmer.confirm("Would you like to generate a graph?") + if not is_confirmed: + raise Exception("The user did not confirm to creating a graph.") local_vars = {} exec(sanitised_script, {}, local_vars) - fig = local_vars.get('fig') + fig = local_vars.get("fig") buf = BytesIO() if fig is None: raise ValueError("The generated code did not produce a figure named 'fig'.") - fig.savefig(buf, format='png') + fig.savefig(buf, format="png") buf.seek(0) with Image.open(buf): image_data = base64.b64encode(buf.getvalue()).decode("utf-8") @@ -44,7 +52,7 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L raise response = { "content": image_data, - "ignore_validation": "false" + "ignore_validation": "true", } return json.dumps(response, indent=4) @@ -57,6 +65,7 @@ def sanitise_script(script: str) -> str: script = script[:-3] return script.strip() + @tool( name="generate_code_chart", description="Generate Matplotlib bar chart code if the user's query involves creating a chart", @@ -74,18 +83,18 @@ def sanitise_script(script: str) -> str: description=""" The specific parameters required for the question to be answered with the question_intent, extracted from data_provided - """), - } + """, + ), + }, ) - async def generate_code_chart(question_intent, data_provided, question_params, llm: LLM, model) -> str: return await generate_chart(question_intent, data_provided, question_params, llm, model) + @agent( name="ChartGeneratorAgent", description="This agent is responsible for creating charts", - tools=[generate_code_chart] + tools=[generate_code_chart], ) - class ChartGeneratorAgent(Agent): pass diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py index c597fcb2c..f876f9a0b 100644 --- a/backend/src/agents/datastore_agent.py +++ b/backend/src/agents/datastore_agent.py @@ -1,3 +1,4 @@ +import json import logging from src.llm.llm import LLM from src.utils.graph_db_utils import execute_query @@ -8,15 +9,51 @@ from src.utils.log_publisher import LogPrefix, publish_log_info from .agent import Agent, agent from .tool import tool -import json - +from src.utils.semantic_layer_builder import get_semantic_layer logger = logging.getLogger(__name__) engine = PromptEngine() -graph_schema = engine.load_prompt("graph-schema") +cache = {} +async def generate_cypher_query_core( + question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model +) -> str: + + details_to_create_cypher_query = engine.load_prompt( + "details-to-create-cypher-query", + question_intent=question_intent, + operation=operation, + question_params=question_params, + aggregation=aggregation, + sort_order=sort_order, + timeframe=timeframe, + ) + try: + graph_schema = await get_semantic_layer_cache(llm, model, cache) + graph_schema = json.dumps(graph_schema, separators=(",", ":")) + + generate_cypher_query_prompt = engine.load_prompt( + "generate-cypher-query", graph_schema=graph_schema, 
current_date=datetime.now() + ) + + llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, + return_json=True) + json_query = to_json(llm_query) + await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__) + if json_query["query"] == "None": + return "No database query" + db_response = execute_query(json_query["query"]) + await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__) + except Exception as e: + logger.error(f"Error during data retrieval: {e}") + raise + response = { + "content": db_response, + "ignore_validation": "false" + } + return json.dumps(response, indent=4) @tool( name="generate cypher query", @@ -51,39 +88,26 @@ ), }, ) -async def generate_query( - question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model -) -> str: - details_to_create_cypher_query = engine.load_prompt( - "details-to-create-cypher-query", - question_intent=question_intent, - operation=operation, - question_params=question_params, - aggregation=aggregation, - sort_order=sort_order, - timeframe=timeframe, - ) - generate_cypher_query_prompt = engine.load_prompt( - "generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now() - ) - llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True) - json_query = to_json(llm_query) - await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__) - if json_query["query"] == "None": - return "No database query" - db_response = execute_query(json_query["query"]) - await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__) - response = { - "content": db_response, - "ignore_validation": "false" - } - return json.dumps(response, indent=4) +async def generate_cypher(question_intent, operation, question_params, aggregation, sort_order, + timeframe, llm: LLM, model) -> str: + return await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order, + timeframe, llm, model) + + +async def get_semantic_layer_cache(llm, model, graph_schema): + global cache + if not cache: + graph_schema = await get_semantic_layer(llm, model) + cache = graph_schema + return cache + else: + return cache @agent( name="DatastoreAgent", description="This agent is responsible for handling database queries relating to the user's personal data.", - tools=[generate_query], + tools=[generate_cypher], ) class DatastoreAgent(Agent): pass diff --git a/backend/src/agents/file_agent.py b/backend/src/agents/file_agent.py new file mode 100644 index 000000000..d8a817b1c --- /dev/null +++ b/backend/src/agents/file_agent.py @@ -0,0 +1,102 @@ +import logging +from .agent_types import Parameter +from .agent import Agent, agent +from .tool import tool +import json +import os +from src.utils.config import Config + +logger = logging.getLogger(__name__) +config = Config() + +FILES_DIRECTORY = f"/app/{config.files_directory}" + +# Constants for response status +IGNORE_VALIDATION = "true" +STATUS_SUCCESS = "success" +STATUS_ERROR = "error" + +# Utility function for error responses +def create_response(content: str, status: str = STATUS_SUCCESS) -> str: + return json.dumps({ + "content": content, + "ignore_validation": IGNORE_VALIDATION, + "status": status + }, indent=4) + +async def read_file_core(file_path: str) -> str: + full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path)) + try: + with 
open(full_path, 'r') as file: + content = file.read() + return create_response(content) + except FileNotFoundError: + error_message = f"File {file_path} not found." + logger.error(error_message) + return create_response(error_message, STATUS_ERROR) + except Exception as e: + logger.error(f"Error reading file {full_path}: {e}") + return create_response(f"Error reading file: {file_path}", STATUS_ERROR) + + +async def write_or_update_file_core(file_path: str, content: str, update) -> str: + full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path)) + try: + if update == "yes": + with open(full_path, 'a') as file: + file.write('\n' +content) + logger.info(f"Content appended to file {full_path} successfully.") + return create_response(f"Content appended to file {file_path}.") + else: + with open(full_path, 'w') as file: + file.write(content) + logger.info(f"Content written to file {full_path} successfully.") + return create_response(f"Content written to file {file_path}.") + except Exception as e: + logger.error(f"Error writing to file {full_path}: {e}") + return create_response(f"Error writing to file: {file_path}", STATUS_ERROR) + + +@tool( + name="read_file", + description="Read the content of a text file.", + parameters={ + "file_path": Parameter( + type="string", + description="The path to the file to be read." + ), + }, +) +async def read_file(file_path: str, llm, model) -> str: + return await read_file_core(file_path) + + +@tool( + name="write_file", + description="Write or update content to a text file.", + parameters={ + "file_path": Parameter( + type="string", + description="The path to the file where the content will be written." + ), + "content": Parameter( + type="string", + description="The content to write to the file." + ), + "update": Parameter( + type="string", + description="if yes then just append the file" + ), + }, +) +async def write_or_update_file(file_path: str, content: str, update, llm, model) -> str: + return await write_or_update_file_core(file_path, content, update) + + +@agent( + name="FileAgent", + description="This agent is responsible for reading from and writing to files.", + tools=[read_file, write_or_update_file], +) +class FileAgent(Agent): + pass diff --git a/backend/src/agents/intent_agent.py b/backend/src/agents/intent_agent.py index ed6b7fcb2..f3b701653 100644 --- a/backend/src/agents/intent_agent.py +++ b/backend/src/agents/intent_agent.py @@ -1,9 +1,22 @@ from src.prompts import PromptEngine from src.agents import Agent, agent +import logging +import os +import json +from src.utils.config import Config + + +config = Config() engine = PromptEngine() intent_format = engine.load_prompt("intent-format") +logger = logging.getLogger(__name__) +FILES_DIRECTORY = f"/app/{config.files_directory}" +# Constants for response status +IGNORE_VALIDATION = "true" +STATUS_SUCCESS = "success" +STATUS_ERROR = "error" @agent( name="IntentAgent", @@ -11,7 +24,33 @@ tools=[], ) class IntentAgent(Agent): + + async def read_file_core(self, file_path: str) -> str: + full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path)) + try: + with open(full_path, 'r') as file: + content = file.read() + return content + except FileNotFoundError: + error_message = f"File {file_path} not found." 
+ logger.error(error_message) + return "" + except Exception as e: + logger.error(f"Error reading file {full_path}: {e}") + return "" + async def invoke(self, utterance: str) -> str: - user_prompt = engine.load_prompt("intent", question=utterance) + chat_history = await self.read_file_core("conversation-history.txt") + + user_prompt = engine.load_prompt("intent", question=utterance, chat_history=chat_history) return await self.llm.chat(self.model, intent_format, user_prompt=user_prompt, return_json=True) + + + # Utility function for error responses +def create_response(content: str, status: str = STATUS_SUCCESS) -> str: + return json.dumps({ + "content": content, + "ignore_validation": IGNORE_VALIDATION, + "status": status + }, indent=4) diff --git a/backend/src/agents/maths_agent.py b/backend/src/agents/maths_agent.py index e8833ffc5..a55506d8f 100644 --- a/backend/src/agents/maths_agent.py +++ b/backend/src/agents/maths_agent.py @@ -1,57 +1,98 @@ from .tool import tool from .agent_types import Parameter from .agent import Agent, agent +import logging +from src.utils import Config +from .validator_agent import ValidatorAgent +import json +from src.utils.web_utils import perform_math_operation_util +logger = logging.getLogger(__name__) +config = Config() -@tool( - name="sum list of values", - description="sums a list of provided values", - parameters={ - "list_of_values": Parameter( - type="list[number]", - description="Python list of comma separated values (e.g. [1, 5, 3])", - ) - }, -) -async def sum_list_of_values(list_of_values) -> str: - if not isinstance(list_of_values, list): - raise Exception("Method not passed a valid Python list") - return f"The sum of all the values passed {list_of_values} is {str(sum(list_of_values))}" +async def perform_math_operation_core(math_query, llm, model) -> str: + try: + # Call the utility function to perform the math operation + math_operation_result = await perform_math_operation_util(math_query, llm, model) + + result_json = json.loads(math_operation_result) + + if result_json.get("status") == "success": + # Extract the relevant response (math result) from the utility function's output + response = result_json.get("response", {}) + response_json = json.loads(response) + result = response_json.get("result", "") + if result: + logger.info(f"Math operation successful: {result}") + is_valid = await is_valid_answer(result, math_query) + logger.info(f"Is the answer valid: {is_valid}") + if is_valid: + response = { + "content": result, + "ignore_validation": "true" + } + return json.dumps(response, indent=4) + else: + response = { + "content": "No valid result found for the math query.", + "ignore_validation": "true" + } + return json.dumps(response, indent=4) + else: + response = { + "content": None, + "status": "error" + } + return json.dumps(response, indent=4) + except Exception as e: + logger.error(f"Error in perform_math_operation_core: {e}") + response = { + "content": None, + "status": "error" + } + return json.dumps(response, indent=4) + + # Ensure a return statement in all code paths + response = { + "content": None, + "status": "error" + } + return json.dumps(response, indent=4) +def get_validator_agent() -> Agent: + return ValidatorAgent(config.validator_agent_llm, config.validator_agent_model) +async def is_valid_answer(answer, task) -> bool: + is_valid = (await get_validator_agent().invoke(f"Task: {task} Answer: {answer}")).lower() == "true" + return is_valid + +# Math Operation Tool @tool( - name="compare two values", - description="Compare 
two passed values and return information on which one is greater", + name="perform_math_operation", + description=( + "Use this tool to perform complex mathematical operations or calculations. " + "It handles arithmetic operations and algebra, and also supports conversions to specific units like millions," + "rounding when necessary. Returns both the result and an explanation of the steps involved." + ), parameters={ - "thing_one": Parameter( - type="string", - description="first thing for comparison", - ), - "value_one": Parameter( - type="number", - description="value of first thing", - ), - "thing_two": Parameter( + "math_query": Parameter( type="string", - description="second thing for comparison", - ), - "value_two": Parameter( - type="number", - description="value of first thing", + description="The mathematical query or equation to solve." ), }, ) -async def compare_two_values(value_one, thing_one, value_two, thing_two) -> str: - if value_one > value_two: - return f"You have spent more on {thing_one} ({value_one}) than {thing_two} ({value_two}) in the last month" - else: - return f"You have spent more on {thing_two} ({value_two}) than {thing_one} ({value_one}) in the last month" - +async def perform_math_operation(math_query, llm, model) -> str: + return await perform_math_operation_core(math_query, llm, model) +# MathAgent definition @agent( name="MathsAgent", - description="This agent is responsible for solving number comparison and calculation tasks", - tools=[sum_list_of_values, compare_two_values], + description=( + "This agent processes mathematical queries, performs calculations, and applies necessary formatting such as" + "rounding or converting results into specific units (e.g., millions). " + "It provides clear explanations of the steps involved to ensure accuracy." + ), + tools=[perform_math_operation], ) class MathsAgent(Agent): pass diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py index 50c9be5a1..714a8c0f3 100644 --- a/backend/src/agents/web_agent.py +++ b/backend/src/agents/web_agent.py @@ -4,7 +4,15 @@ from .agent import Agent, agent from .tool import tool from src.utils import Config -from src.utils.web_utils import search_urls, scrape_content, summarise_content, summarise_pdf_content +from src.utils.web_utils import ( + search_urls, + scrape_content, + summarise_content, + summarise_pdf_content, + find_info, + create_search_term, + answer_user_ques +) from .validator_agent import ValidatorAgent import aiohttp import io @@ -20,21 +28,68 @@ async def web_general_search_core(search_query, llm, model) -> str: try: - search_result = perform_search(search_query, num_results=15) - if search_result["status"] == "error": + # Step 1: Generate the search term from the user's query + answer_to_user = await answer_user_ques(search_query, llm, model) + answer_result = json.loads(answer_to_user) + if answer_result["status"] == "error": + response = { + "content": "Error in finding the answer.", + "ignore_validation": "false" + } + return json.dumps(response, indent=4) + logger.info(f'Answer found successfully {answer_result}') + valid_answer = json.loads(answer_result["response"]).get("is_valid", "") + if valid_answer: + final_answer = json.loads(answer_result["response"]).get("answer", "") + if not final_answer: + return "No answer found." 
+ logger.info(f'Answer found successfully {final_answer}') + response = { + "content": final_answer, + "ignore_validation": "false" + } + return json.dumps(response, indent=4) + else: + search_term_json = await create_search_term(search_query, llm, model) + search_term_result = json.loads(search_term_json) + + # Check if there was an error in generating the search term + if search_term_result.get("status") == "error": + response = { + "content": search_term_result.get("error"), + "ignore_validation": "false" + } + return json.dumps(response, indent=4) + search_term = json.loads(search_term_result["response"]).get("search_term", "") + + # Step 2: Perform the search using the generated search term + search_result = await perform_search(search_term, num_results=15) + if search_result.get("status") == "error": + return "No relevant information found on the internet for the given query." + urls = search_result.get("urls", []) + logger.info(f"URLs found: {urls}") + + # Step 3: Scrape content from the URLs found + for url in urls: + content = await perform_scrape(url) + if not content: + continue # Skip to the next URL if no content is found + # logger.info(f"Content scraped successfully: {content}") + # Step 4: Summarize the scraped content based on the search term + summary = await perform_summarization(search_term, content, llm, model) + if not summary: + continue # Skip if no summary was generated + + # Step 5: Validate the summarization + is_valid = await is_valid_answer(summary, search_term) + if not is_valid: + continue # Skip if the summarization is not valid + response = { + "content": summary, + "ignore_validation": "false" + } + return json.dumps(response, indent=4) return "No relevant information found on the internet for the given query." - urls = search_result["urls"] - logger.info(f"URLs found: {urls}") - for url in urls: - content = await perform_scrape(url) - if not content: - continue - summary = await perform_summarization(search_query, content, llm, model) - if not summary: - continue - if await is_valid_answer(summary, search_query): - return summary - return "No relevant information found on the internet for the given query." except Exception as e: logger.error(f"Error in web_general_search_core: {e}") return "An error occurred while processing the search query." @@ -96,6 +151,74 @@ async def web_general_search(search_query, llm, model) -> str: async def web_pdf_download(pdf_url, llm, model) -> str: return await web_pdf_download_core(pdf_url, llm, model) +async def web_scrape_core(url: str) -> str: + try: + # Scrape the content from the provided URL + content = await perform_scrape(url) + if not content: + return "No content found at the provided URL." 
+ logger.info(f"Content scraped successfully: {content}") + content = content.replace("\n", " ").replace("\r", " ") + response = { + "content": content, + "ignore_validation": "true" + } + return json.dumps(response, indent=4) + except Exception as e: + return json.dumps({"status": "error", "error": str(e)}) + + +@tool( + name="web_scrape", + description="Scrapes the content from the given URL.", + parameters={ + "url": Parameter( + type="string", + description="The URL of the page to scrape the content from.", + ), + }, +) +async def web_scrape(url: str, llm, model) -> str: + logger.info(f"Scraping the content from URL: {url}") + return await web_scrape_core(url) + + +async def find_information_from_content_core(content: str, question, llm, model) -> str: + try: + find_info_json = await find_info(content, question, llm, model) + info_result = json.loads(find_info_json) + if info_result["status"] == "error": + return "" + final_info = info_result["response"] + if not final_info: + return "No information found from the content." + logger.info(f"Content scraped successfully: {content}") + response = { + "content": final_info, + "ignore_validation": "true" + } + return json.dumps(response, indent=4) + except Exception as e: + logger.error(f"Error finding information: {e}") + return "" + +@tool( + name="find_information_content", + description="Finds the information from the content.", + parameters={ + "content": Parameter( + type="string", + description="The content to find the information from.", + ), + "question": Parameter( + type="string", + description="The question to find the information from the content.", + ), + }, +) +async def find_information_from_content(content: str, question: str, llm, model) -> str: + return await find_information_from_content_core(content, question, llm, model) + def get_validator_agent() -> Agent: return ValidatorAgent(config.validator_agent_llm, config.validator_agent_model) @@ -105,9 +228,9 @@ async def is_valid_answer(answer, task) -> bool: return is_valid -def perform_search(search_query: str, num_results: int) -> Dict[str, Any]: +async def perform_search(search_query: str, num_results: int) -> Dict[str, Any]: try: - search_result_json = search_urls(search_query, num_results=num_results) + search_result_json = await search_urls(search_query, num_results=num_results) return json.loads(search_result_json) except Exception as e: logger.error(f"Error during web search: {e}") @@ -132,7 +255,8 @@ async def perform_summarization(search_query: str, content: str, llm: Any, model summarise_result = json.loads(summarise_result_json) if summarise_result["status"] == "error": return "" - return summarise_result["response"] + logger.info(f"Content summarized successfully: {summarise_result['response']}") + return json.loads(summarise_result["response"])["summary"] except Exception as e: logger.error(f"Error summarizing content: {e}") return "" @@ -150,8 +274,15 @@ async def perform_pdf_summarization(content: str, llm: Any, model: str) -> str: @agent( name="WebAgent", - description="This agent is responsible for handling web search queries and summarizing information from the web.", - tools=[web_general_search, web_pdf_download], + description="""This agent specializes in handling tasks related to web content extraction, search, and + summarization. + It can perform the following functions: + Web scraping: Extracts data from given URLs, enabling tasks like retrieving specific information from web pages. 
+ Finding Information from Content: Extracts specific information from the content provided. + Internet search: Conducts general online searches based on queries, retrieving and summarizing relevant content from + multiple sources. + PDF content extraction: Downloads and summarizes the content of PDF documents from provided URLs.""", + tools=[web_general_search, web_pdf_download, web_scrape, find_information_from_content], ) class WebAgent(Agent): pass diff --git a/backend/src/llm/mistral.py b/backend/src/llm/mistral.py index 8fac39101..18974b4ef 100644 --- a/backend/src/llm/mistral.py +++ b/backend/src/llm/mistral.py @@ -1,5 +1,4 @@ -from mistralai.async_client import MistralAsyncClient -from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage +from mistralai import Mistral as MistralApi, UserMessage, SystemMessage import logging from src.utils import Config from .llm import LLM @@ -9,21 +8,27 @@ class Mistral(LLM): - client = MistralAsyncClient(api_key=config.mistral_key) + client = MistralApi(api_key=config.mistral_key) async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str: logger.debug("Called llm. Waiting on response model with prompt {0}.".format(str([system_prompt, user_prompt]))) - response: ChatCompletionResponse = await self.client.chat( + response = await self.client.chat.complete_async( model=model, messages=[ - ChatMessage(role="system", content=system_prompt), - ChatMessage(role="user", content=user_prompt), + SystemMessage(content=system_prompt), + UserMessage(content=user_prompt), ], temperature=0, response_format={"type": "json_object"} if return_json else None, ) - logger.debug('{0} response : "{1}"'.format(model, response.choices[0].message.content)) + if not response or not response.choices: + logger.error("Call to Mistral API failed: No valid response or choices received") + return "An error occurred while processing the request." content = response.choices[0].message.content + if not content: + logger.error("Call to Mistral API failed: message content is None or Unset") + return "An error occurred while processing the request." - return content if isinstance(content, str) else " ".join(content) + logger.debug('{0} response : "{1}"'.format(model, content)) + return content diff --git a/backend/src/prompts/templates/answer-user-ques.j2 b/backend/src/prompts/templates/answer-user-ques.j2 new file mode 100644 index 000000000..b7d9a1305 --- /dev/null +++ b/backend/src/prompts/templates/answer-user-ques.j2 @@ -0,0 +1,54 @@ +You are an expert in providing accurate and complete answers to user queries. Your task is twofold: + +1. **Generate a detailed answer** to the user's question based on the provided content or context. +2. **Validate** if the generated answer directly addresses the user's question and is factually accurate. + +User's question is: +{{ question }} + +Once you generate an answer: +- **Check** if the answer completely and accurately addresses the user's question. +- **Determine** if the answer is valid, based on the content provided. + +Reply only in JSON format with the following structure: + +```json +{ + "answer": "The answer to the user's question, based on the content provided", + "is_valid": true or false, + "validation_reason": "A sentence explaining whether the answer is valid or not, and why" +} + + + +### **Explanation:** + +1. **Answer**: The LLM generates an answer based on the user’s question and the provided content. +2. 
**Validity Check**: The LLM checks if its generated answer is complete and correct. This could be based on factual accuracy, coverage of the query, or relevance to the user's question. +3. **Validation Reason**: The LLM explains why the answer is valid or invalid. + +### **Example of Usage:** + +#### **User’s Question:** +- **"What is Tesla's revenue every year since its creation?"** + +#### **Content Provided:** +- A table or a paragraph with data on Tesla's revenue for various years. + +#### **LLM’s Response:** + +```json +{ + "answer": "Tesla's revenue since its creation is: 2008: $15 million, 2009: $30 million, ..., 2023: $81 billion.", + "is_valid": true, + "validation_reason": "The answer includes Tesla's revenue for every year since its creation, based on the data provided." +} + +{ + "answer": "Tesla's revenue for 2010 to 2023 is available, but data for the earlier years is missing.", + "is_valid": false, + "validation_reason": "The answer is incomplete because data for Tesla's early years is missing." +} + + +Important: If the question is related to real time data, the LLM should provide is_valid is false. diff --git a/backend/src/prompts/templates/best-next-step.j2 b/backend/src/prompts/templates/best-next-step.j2 index 417d357c8..ff6ef6c98 100644 --- a/backend/src/prompts/templates/best-next-step.j2 +++ b/backend/src/prompts/templates/best-next-step.j2 @@ -22,11 +22,12 @@ Here is the list of Agents you can choose from: AGENT LIST: {{ list_of_agents }} -If the list of agents does not contain something suitable, you should say the agent is 'none'. ie. If question is 'general knowledge', 'personal' or a 'greeting'. +If the list of agents does not contain something suitable, you should say the agent is 'WebAgent'. ie. If question is 'general knowledge', 'personal' or a 'greeting'. ## Determine the next best step Your task is to pick one of the mentioned agents above to complete the task. If the same agent_name and task are repeated more than twice in the history, you must not pick that agent_name. +If mathematical processing (e.g., rounding or calculations) is needed, choose the MathsAgent. If file operations are needed, choose the FileAgent. Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications. diff --git a/backend/src/prompts/templates/best-tool.j2 b/backend/src/prompts/templates/best-tool.j2 index 51d93f5de..7452cf1a4 100644 --- a/backend/src/prompts/templates/best-tool.j2 +++ b/backend/src/prompts/templates/best-tool.j2 @@ -11,11 +11,16 @@ Trust the information below completely (100% accurate) Pick 1 tool (no more than 1) from the list below to complete this task. Fit the correct parameters from the task to the tool arguments. +Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "£") and units of measurement (e.g., "million") if applicable. Parameters with required as False do not need to be fit. Add if appropriate, but do not hallucinate arguments for these parameters {{ tools }} +Important: +If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "£") and rounded to the nearest million if specified. +If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "£15 million", "£5000"). + From the task you should be able to extract the parameters. 
If it is data driven, it should be turned into a cypher query If none of the tools are appropriate for the task, return the following tool diff --git a/backend/src/prompts/templates/create-search-term.j2 b/backend/src/prompts/templates/create-search-term.j2 new file mode 100644 index 000000000..cc46787ac --- /dev/null +++ b/backend/src/prompts/templates/create-search-term.j2 @@ -0,0 +1,16 @@ +You are an expert at crafting Google search terms. Your goal is to generate an optimal search query based on the user's question to find the most relevant information on Google. + +Your entire purpose is to analyze the user's query, extract the essential keywords, and create a concise, well-structured search term that will yield the most accurate and useful results when used in a Google search. + +Ensure that the search query: + +Is relevant to the user’s question. +Contains the right combination of keywords. +Avoids unnecessary words, focusing only on what is critical for finding the right information. +User's question is: {{ question }} + +Reply only in JSON format, following this structure: +{ + "search_term": "The optimized Google search term based on the user's question", + "reasoning": "A sentence on why you chose that search term" +} diff --git a/backend/src/prompts/templates/find-info.j2 b/backend/src/prompts/templates/find-info.j2 new file mode 100644 index 000000000..8fb780133 --- /dev/null +++ b/backend/src/prompts/templates/find-info.j2 @@ -0,0 +1,17 @@ +You are an expert information extractor. Your goal is to find specific data from the content provided and answer the user's question directly. + +You will be given a user query and content scraped from the web. Your task is to carefully examine the content and extract the exact information relevant to the query. + +Ensure that your response is precise and focused, only providing the data that directly answers the user's question. + +User's question is: {{ question }} + +Below is the content scraped from the web: {{ content | replace("\n\n", "\n") }} + +Reply only in JSON format as follows: + +{ + "extracted_info": "The exact information that answers the user's query", + "reasoning": "A brief explanation of how the extracted information is relevant" +} + diff --git a/backend/src/prompts/templates/generate-cypher-query.j2 b/backend/src/prompts/templates/generate-cypher-query.j2 index 36ae24e6e..597df245c 100644 --- a/backend/src/prompts/templates/generate-cypher-query.j2 +++ b/backend/src/prompts/templates/generate-cypher-query.j2 @@ -30,4 +30,4 @@ When returning a value, always remove the `-` sign before the number. Here is the graph schema: {{ graph_schema }} -The current date and time is {{ current_date }} +The current date and time is {{ current_date }} and the currency of the data is GBP. 
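Note on the deletion that follows: the static graph-schema.j2 template is removed because the datastore agent now builds the schema at runtime via `get_semantic_layer` and memoises it in a module-level `cache` dict (see the `get_semantic_layer_cache` helper added to datastore_agent.py above). A minimal sketch of that memoisation pattern, with a hypothetical `build_schema` coroutine standing in for the real `get_semantic_layer`:

```python
import asyncio

_cache: dict = {}

async def build_schema(llm, model) -> dict:
    # Hypothetical stand-in for get_semantic_layer(llm, model), which queries
    # Neo4j and enriches the result with the LLM; here it just returns a stub.
    await asyncio.sleep(0)
    return {"nodes": [], "relationships": []}

async def get_schema_cached(llm, model) -> dict:
    # Build the semantic layer once per process and reuse it afterwards,
    # mirroring the module-level cache used in datastore_agent.py.
    global _cache
    if not _cache:
        _cache = await build_schema(llm, model)
    return _cache

# Example: asyncio.run(get_schema_cached(llm=None, model=None))
```

A caveat of this pattern is that the cache lives for the lifetime of the process, so schema changes in Neo4j are only picked up after a restart.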
diff --git a/backend/src/prompts/templates/graph-schema.j2 b/backend/src/prompts/templates/graph-schema.j2 deleted file mode 100644 index 38892655c..000000000 --- a/backend/src/prompts/templates/graph-schema.j2 +++ /dev/null @@ -1,156 +0,0 @@ -{ - "nodes": { - "labels": [ - { - "label": "Account", - "detail": "An Account is a unique user profile with specific identifiers and associated transactions.", - "cypher_representation": "(:Account)" - }, - { - "label": "Transaction", - "detail": "A Transaction is a record of a financial exchange between an Account and a Merchant, containing details like date, amount, etc.", - "cypher_representation": "(:Transaction)" - }, - { - "label": "Merchant", - "detail": "A Merchant is an entity that provides goods or services in exchange for payment, linked to Transactions.", - "cypher_representation": "(:Merchant)" - }, - { - "label": "Classification", - "detail": "A Classification is a category assigned to a Transaction, based on the type of purchase or service.", - "cypher_representation": "(:Classification)" - } - ] - }, - "properties": { - "node_properties": [ - { - "node_label": "Transaction", - "properties": [ - { - "name": "id", - "data_type": "String", - "detail": "Unique identifier for the transaction" - }, - { - "name": "amount", - "data_type": "Long", - "detail": "The amount of money involved in the transaction" - }, - { - "name": "description", - "data_type": "String", - "detail": "A short explanation or reason for the transaction" - }, - { - "name": "date", - "data_type": "DateTime", - "detail": "The date and time when the transaction occurred" - }, - { - "name": "type", - "data_type": "String", - "detail": "The category or type of the transaction. One of: DEBIT, CREDIT, TRANSFER" - } - ] - }, - { - "node_label": "Merchant", - "properties": [ - { - "name": "name", - "data_type": "String", - "detail": "The name of the merchant / company involved in the transaction" - } - ] - }, - { - "node_label": "Classification", - "properties": [ - { - "name": "name", - "data_type": "String", - "detail": "The category or classification of the transaction" - } - ] - }, - { - "node_label": "Account", - "properties": [ - { - "name": "name", - "data_type": "String", - "detail": "The name or identifier of the account involved in the transaction" - } - ] - } - ], - "relationship_properties": [ - { - "relationship_type": "[:PAID_BY]", - "properties": [ - { - "name": "transaction_id", - "data_type": "String", - "detail": "Represents the payment of a Transaction by an Account" - } - ] - }, - { - "relationship_type": "[:PAID_TO]", - "properties": [ - { - "name": "transaction_id", - "data_type": "String", - "detail": "Represents the payment for a Transaction received by a Merchant" - } - ] - }, - { - "relationship_type": "[:CLASSIFIED_AS]", - "properties": [ - { - "name": "category", - "data_type": "String", - "detail": "Represents the categorization or classification of a transaction" - } - ] - } - ] - }, - "relationships": { - "paths": [ - { - "label": "[:PAID_TO]", - "detail": "Represents a payment made to a merchant", - "cypher_representation": "(:Transaction)-[:PAID_TO]->(:Merchant)" - }, - { - "label": "[:PAID_BY]", - "detail": "Represents a transaction that is paid by a specific account", - "cypher_representation": "(:Transaction)-[:PAID_BY]->(:Account)" - }, - { - "label": "[:CLASSIFIED_AS]", - "detail": "Represents the classification of a node", - "cypher_representation": "(:Transaction)-[:CLASSIFIED_AS]->(:Classification)" - }, - { - "label": "[:PAID_TO]", - 
"detail": "Represents a payment made to a node", - "cypher_representation": "(:Transaction)-[:PAID_TO]->(:Merchant)" - }, - { - "label": "[:PAID_BY]", - "detail": "Represents the account that made a transaction *NOTE DIRECTION OF ARROW*", - "cypher_representation": "(:Transaction)-[:PAID_BY]->(:Account)" - }, - { - "label": "[:CLASSIFIED_AS]", - "detail": "Represents the classification of a transaction", - "cypher_representation": "(:Transaction)-[:CLASSIFIED_AS]->(:Classification)" - } - ] - } -} diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2 index 84361b217..a112f3221 100644 --- a/backend/src/prompts/templates/intent.j2 +++ b/backend/src/prompts/templates/intent.j2 @@ -1,27 +1,40 @@ You are an expert in determining the intent behind a user's question. - The question is: {{ question }} - -The task is to comprehend the intention of the question. The question can be composed of different intents and when it is the case, examine all intents one by one to determine which one to tackle first as you may need the data gathered from a secondary intent to perform the first intent. -You are NOT ALLOWED to make up sample data or example values. Only use concrete data for which you can name the source. -Based on this understanding, the following query must be formulated to extract the necessary data, which can then be used to address the question. +The previous chat history is: - -Specify an operation type under the operation key; here are a few examples: +{{ chat_history }} -* "literal search" - This should be used when the user is looking to find precise information, such as known facts -* "relevancy search" - This should be used when the user is looking to find something that is not a literal and is fuzzy -* "filter + aggregation" - This should be used when they want something like a count, where there will be only 1 number returned -* "filter + aggregation + sort" - This should be used when multiple numbers will be returned -* "filter + sort" - This should be used when no aggregation is required e.g. count +Your task is to accurately comprehend the intentions behind the current question. +The question can be composed of different intents and may depend on the context provided by the previous question and its response. +- You must evaluate whether the current question is directly related to or dependent on the information from the previous interaction. +- If the current question builds on the previous one, make sure to use the relevant data from the previous response to inform the current query. -Examples: +The question may contain multiple intents. Examine each intent and determine the order in which they should be tackled, ensuring each intent is addressed logically. If one intent depends on data from another, sequence them accordingly. +Use the following guidelines: + +1. Determine distinct intents in the question. +2. For each intent, specify: + - The exact operation required (e.g., "literal search", "filter + aggregation"). + - The category of the question (e.g., "data-driven", "general knowledge"). + - Any specific parameters or conditions that apply. + - If related to the previous response, include parameters derived from the previous interaction. +3. Sequence the intents logically if there are multiple, ensuring any dependent intents are handled last. +4. For each intent, clarify the operation, aggregation, sorting, and any timeframe or other parameters. +5. 
Avoid conflating intents: If a user's query asks for data retrieval and its visualization, treat these as separate operations. +6. Use the chat history to figure out the correct context if the user's question is a bit vague. + +Examples of common operations: +- Literal search for factual information. +- Filter + aggregation for tasks like counting or summing. +- Data transformation for numerical operations like rounding. + +Examples Q: How much did I spend with Amazon this month? Response: {"query":"How much did I spend with Amazon this month?","user_intent":"sum amount spent","questions":[{"query":"How much did I spend with Amazon this month?","question_intent":"calculate total expenses","operation":"filter + aggregation","question_category": "data driven","parameters":[{"type":"company","value":"Amazon"}],"timeframe":"this month","aggregation":"sum","sort_order":"none"}]} @@ -53,3 +66,22 @@ Response: Q: Find the schedule of the local train station. Response: {"query":"Find the schedule of the local train station.","user_intent":"find train schedule","questions":[{"query":"Find the schedule of the local train station.","question_intent":"retrieve train schedule from web","operation":"online search","question_category":"search online","parameters":[{"type":"train station","value":"local"}],"sort_order":"none"}]} + +Q: What are the different subscriptions with Netflix? Show me the results in a chart. +Response: +{"query": "What are the different subscriptions with Netflix? Show me the results in a chart.", "user_intent": "find and display subscription information", "questions": [{"query": "What are the different subscriptions with Netflix?", "question_intent": "retrieve subscription information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Netflix"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Show me the results in a chart", "question_intent": "display information in a chart", "operation": "data visualization", "question_category": "data presentation", "parameters": [], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]} + +Q: Show me a chart of different subscription prices with Netflix? 
+Response: +{"query": "Show me a chart of different subscription prices with Netflix?", "user_intent": "retrieve and visualize subscription data", "questions": [{"query": "What are the different subscription prices with Netflix?", "question_intent": "retrieve subscription pricing information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Netflix"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Show me the results in a chart", "question_intent": "display subscription pricing information in a chart", "operation": "data visualization", "question_category": "data presentation", "parameters": [], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]} + +Finally, if no tool fits the task, return the following: + +{ + "tool_name": "None", + "tool_parameters": "{}", + "reasoning": "No tool was appropriate for the task" +} + +Important: +Please always create the last intent to append the retrieved info in a 'conversation-history.txt' file and make sure this history file is always named 'conversation-history.txt' diff --git a/backend/src/prompts/templates/math-solver.j2 b/backend/src/prompts/templates/math-solver.j2 new file mode 100644 index 000000000..122e88d2e --- /dev/null +++ b/backend/src/prompts/templates/math-solver.j2 @@ -0,0 +1,26 @@ +You are an expert in performing mathematical operations. You are highly skilled in handling various mathematical queries such as performing arithmetic operations, applying formulas, and expressing numbers in different formats as requested by the user. + +You will be given a mathematical query, and your task is to solve the query based on the provided information. Ensure that you apply the appropriate mathematical principles to deliver an exact result. **Only convert numbers to millions if explicitly requested by the user.** Otherwise, return the result as is, without unnecessary conversions. + +Make sure to perform the calculations step by step when necessary, and return the final result clearly. + +User's query is: +{{ query }} + +Reply only in json with the following format: + +{ + "result": "The final result of the mathematical operation, without unnecessary conversion to millions or any other format unless explicitly requested", + "steps": "A breakdown of the steps involved in solving the query (if applicable)", + "reasoning": "A sentence on why this result is accurate" +} + +Following is an example of the query and the expected response format: +query: Round 81.462 billion to the nearest million + +{ + "result": "81,462 million", + "steps": "1. Convert 81.462 billion to million by multiplying by 1000. Round the result to the nearest million.", + "reasoning": "Rounding to the nearest million ensures that the result is represented in a more practical figure, without exceeding or falling short of the actual value." 
+} + diff --git a/backend/src/prompts/templates/node-property-cypher-query.j2 b/backend/src/prompts/templates/node-property-cypher-query.j2 index 0f3a5bf2b..0d22bc0fb 100644 --- a/backend/src/prompts/templates/node-property-cypher-query.j2 +++ b/backend/src/prompts/templates/node-property-cypher-query.j2 @@ -7,7 +7,7 @@ WITH detail : "" }) as props RETURN COLLECT({ - node_label: node, + label: node, cypher_representation : "(:" + node + ")", properties: props }) AS nodeProperties diff --git a/backend/src/prompts/templates/nodes-query.j2 b/backend/src/prompts/templates/nodes-query.j2 deleted file mode 100644 index cd1e2207e..000000000 --- a/backend/src/prompts/templates/nodes-query.j2 +++ /dev/null @@ -1,6 +0,0 @@ -call db.labels() yield label -return collect({ - label: label, - cypher_representation : "(:" + label + ")", - detail: "A " + label + " is a..." -}) AS nodes diff --git a/backend/src/prompts/templates/relationship-property-cypher-query.j2 b/backend/src/prompts/templates/relationship-property-cypher-query.j2 index f03d64b9a..952c11ce0 100644 --- a/backend/src/prompts/templates/relationship-property-cypher-query.j2 +++ b/backend/src/prompts/templates/relationship-property-cypher-query.j2 @@ -4,7 +4,7 @@ WITH COLLECT({ name: propertyName, data_type: propertyTypes, - detail: "A " + propertyName + " is a.. " + detail: "" }) AS props RETURN COLLECT({ relationship_type: "[" + REPLACE(rel, "`", "") + "]", diff --git a/backend/src/prompts/templates/relationships-query.j2 b/backend/src/prompts/templates/relationships-query.j2 index f59ca8343..54e1a73e1 100644 --- a/backend/src/prompts/templates/relationships-query.j2 +++ b/backend/src/prompts/templates/relationships-query.j2 @@ -1,12 +1 @@ -CALL apoc.meta.stats() YIELD relTypes -WITH relTypes, keys(relTypes) AS relTypeKeys -UNWIND relTypeKeys AS relTypeKey -WITH relTypeKey, relTypes[relTypeKey] AS count -WHERE relTypeKey CONTAINS ")->(:" - OR relTypeKey CONTAINS "(:" -WITH collect({ - label: split(split(relTypeKey, "-")[1], ">")[0], - cypher_representation: relTypeKey, - detail: "" -}) AS paths -RETURN paths \ No newline at end of file +call db.schema.visualization diff --git a/backend/src/prompts/templates/summariser.j2 b/backend/src/prompts/templates/summariser.j2 index 4a54cd520..33a1bf03f 100644 --- a/backend/src/prompts/templates/summariser.j2 +++ b/backend/src/prompts/templates/summariser.j2 @@ -6,12 +6,11 @@ You will be passed a user query and the content scraped from the web. You need t Ensure the summary is clear, well-structured, and directly addresses the user's query. - User's question is: {{ question }} Below is the content scraped from the web: -{{ content }} +{{ content | replace("\n\n", "\n") }} # Adding this will introduce breaks between paragraphs Reply only in json with the following format: @@ -19,10 +18,3 @@ Reply only in json with the following format: "summary": "The summary of the content that answers the user's query", "reasoning": "A sentence on why you chose that summary" } - -e.g. -Task: What is the capital of England -{ - "summary": "The capital of England is London.", - "reasoning": "London is widely known as the capital of England, a fact mentioned in various authoritative sources and geographical references." 
-} diff --git a/backend/src/prompts/templates/tool-selection-format.j2 b/backend/src/prompts/templates/tool-selection-format.j2 index 542d2f7c1..c1cba32ca 100644 --- a/backend/src/prompts/templates/tool-selection-format.j2 +++ b/backend/src/prompts/templates/tool-selection-format.j2 @@ -1,4 +1,5 @@ -Reply only in json with the following format: +Reply only in json with the following format. In the tool_parameters, please include the currency and measuring scale used in the content provided: + { "tool_name": "the exact string name of the tool chosen", diff --git a/backend/src/prompts/templates/validator.j2 b/backend/src/prompts/templates/validator.j2 index 8fe5427db..54256b525 100644 --- a/backend/src/prompts/templates/validator.j2 +++ b/backend/src/prompts/templates/validator.j2 @@ -24,6 +24,20 @@ Answer: Last month you spend £64.21 on Spotify Response: False Reasoning: The answer is for Spotify not Amazon. +Task: Please find Tesla's revenue every year since its creation. +Answer: Tesla's annual revenue history from FY 2008 to FY 2023 is available, with figures for 2008 through 2020 taken from previous annual reports. +Response: False +Reasoning: The answer is not providing any actual figures but just talks about the figures. + +Task: Please find Tesla's revenue every year since its creation in US dollars. +Answer: Tesla's annual revenue in USD since its creation is as follows: 2024 (TTM) $75.92 billion, 2023 $75.95 billion, 2022 $67.33 billion, 2021 $39.76 billion, 2020 $23.10 billion, 2019 $18.52 billion, 2018 $16.81 billion, 2017 $8.70 billion, 2016 $5.67 billion, 2015 $2.72 billion, 2014 $2.05 billion, 2013 $1.21 billion, 2012 $0.25 billion, 2011 $0.13 billion, 2010 $75.88 million, 2009 $69.73 million. +Response: False +Reasoning: The answer is providing the revenue in GBP not USD. + +Task: Round the following numbers to the nearest million dollars: 96.77B, 81.46B, 53.82B, 31.54B, 24.58B, 21.46B +Answer: 96,770 million, 81,460 million, 53,820 million, 31,540 million, 24,580 million, 21,460 million +Response: True + You must always return a single boolean value as the response. Do not return any additional information, just the boolean value.
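Context for the supervisor.py change below: it consumes the new `status` field that FileAgent (and the reworked MathsAgent) now include in their JSON responses and raises when an agent reports an error. A small sketch of how such a response is parsed, matching the defaults used in `solve_task`; the payload values here are made up, only the key names come from the diff:

```python
import json

# Example agent answer in the shape produced by the create_response helper
# added in file_agent.py ("content" / "ignore_validation" / "status").
answer = json.dumps({
    "content": "File output.txt not found.",
    "ignore_validation": "true",
    "status": "error",
})

parsed = json.loads(answer)
status = parsed.get("status", "success")           # agents that omit status still count as success
ignore_validation = parsed.get("ignore_validation", "")
content = parsed.get("content", "")

if status == "error":
    # solve_all raises at this point so the error ends up in the scratchpad
    print(f"Agent reported an error: {content}")
```

Because older agents never set `status`, the `"success"` default keeps their behaviour unchanged.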
diff --git a/backend/src/supervisors/supervisor.py b/backend/src/supervisors/supervisor.py index f9a87478c..85c030e00 100644 --- a/backend/src/supervisors/supervisor.py +++ b/backend/src/supervisors/supervisor.py @@ -20,25 +20,30 @@ async def solve_all(intent_json) -> None: for question in questions: try: - (agent_name, answer) = await solve_task(question, get_scratchpad()) + (agent_name, answer, status) = await solve_task(question, get_scratchpad()) update_scratchpad(agent_name, question, answer) + if status == "error": + raise Exception(answer) except Exception as error: update_scratchpad(error=error) -async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str]: +async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]: if attempt == 5: raise Exception(unsolvable_response) agent = await get_agent_for_task(task, scratchpad) + logger.info(f"Agent selected: {agent}") if agent is None: raise Exception(no_agent_response) + logger.info(f"Task is {task}") answer = await agent.invoke(task) parsed_json = json.loads(answer) + status = parsed_json.get('status', 'success') ignore_validation = parsed_json.get('ignore_validation', '') answer_content = parsed_json.get('content', '') if(ignore_validation == 'true') or await is_valid_answer(answer_content, task): - return (agent.name, answer_content) + return (agent.name, answer_content, status) return await solve_task(task, scratchpad, attempt + 1) diff --git a/backend/src/utils/config.py b/backend/src/utils/config.py index 213a4feb8..127b445a3 100644 --- a/backend/src/utils/config.py +++ b/backend/src/utils/config.py @@ -3,6 +3,7 @@ default_frontend_url = "http://localhost:8650" default_neo4j_uri = "bolt://localhost:7687" +default_files_directory = "files" class Config(object): @@ -25,6 +26,7 @@ def __init__(self): self.maths_agent_llm = None self.web_agent_llm = None self.chart_generator_llm = None + self.file_agent_llm = None self.router_llm = None self.validator_agent_model = None self.intent_agent_model = None @@ -33,6 +35,8 @@ def __init__(self): self.chart_generator_model = None self.web_agent_model = None self.router_model = None + self.files_directory = default_files_directory + self.file_agent_model = None self.load_env() def load_env(self): @@ -49,6 +53,7 @@ def load_env(self): self.neo4j_uri = os.getenv("NEO4J_URI", default_neo4j_uri) self.neo4j_user = os.getenv("NEO4J_USERNAME") self.neo4j_password = os.getenv("NEO4J_PASSWORD") + self.files_directory = os.getenv("FILES_DIRECTORY", default_files_directory) self.azure_storage_connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") self.azure_storage_container_name = os.getenv("AZURE_STORAGE_CONTAINER_NAME") self.azure_initial_data_filename = os.getenv("AZURE_INITIAL_DATA_FILENAME") @@ -57,6 +62,7 @@ def load_env(self): self.validator_agent_llm = os.getenv("VALIDATOR_AGENT_LLM") self.datastore_agent_llm = os.getenv("DATASTORE_AGENT_LLM") self.chart_generator_llm = os.getenv("CHART_GENERATOR_LLM") + self.file_agent_llm = os.getenv("FILE_AGENT_LLM") self.web_agent_llm = os.getenv("WEB_AGENT_LLM") self.maths_agent_llm = os.getenv("MATHS_AGENT_LLM") self.router_llm = os.getenv("ROUTER_LLM") @@ -68,6 +74,7 @@ def load_env(self): self.chart_generator_model = os.getenv("CHART_GENERATOR_MODEL") self.maths_agent_model = os.getenv("MATHS_AGENT_MODEL") self.router_model = os.getenv("ROUTER_MODEL") + self.file_agent_model = os.getenv("FILE_AGENT_MODEL") except FileNotFoundError: raise FileNotFoundError("Please provide a .env file. 
See the Getting Started guide on the README.md") except Exception: diff --git a/backend/src/utils/semantic_layer_builder.py b/backend/src/utils/semantic_layer_builder.py index e2418c3f7..c030e6237 100644 --- a/backend/src/utils/semantic_layer_builder.py +++ b/backend/src/utils/semantic_layer_builder.py @@ -1,80 +1,159 @@ -# NOT NEEDED CURRENTLY AS THE RETURNED VALUE OF GRAPH SCHEMA IS STORED STATICALLY AS A JINJA TEMPLATE -# THIS FILE IS CURRENTLY BROKEN BUT AS UNUSED THE FOLLOWING LINE IS SUPRESSING ERRORS -# REMOVE NEXT LINE BEFORE WORKING ON FILE -# pyright: reportAttributeAccessIssue=none -from src.llm import call_model from src.utils.graph_db_utils import execute_query import logging from src.prompts import PromptEngine import json -import ast logger = logging.getLogger(__name__) engine = PromptEngine() +relationship_property_query = engine.load_prompt("relationship-property-cypher-query") +node_property_query = engine.load_prompt("node-property-cypher-query") -def get_semantic_layer(): - finalised_graph_structure = {"nodes": {}, "properties": {}} +neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why") - neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why") +neo4j_nodes_understanding_prompt = engine.load_prompt( + "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt +) - relationship_query = engine.load_prompt("relationships-query") +neo4j_relationship_property_prompt = engine.load_prompt( + "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt +) - node_query = engine.load_prompt("nodes-query") +neo4j_node_property_prompt = engine.load_prompt( + "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt +) +relationship_query = engine.load_prompt("relationships-query") - relationship_property_query = engine.load_prompt("relationship-property-cypher-query") +neo4j_relationships_understanding_prompt = engine.load_prompt( + "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt +) - node_property_query = engine.load_prompt("node-property-cypher-query") +async def get_semantic_layer(llm, model): + finalised_graph_structure = {"nodes": {}, "properties": {}} - neo4j_relationships_understanding_prompt = engine.load_prompt( - "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt - ) + relationship_result = execute_query(relationship_query) + payload = relationship_result[0] + + nodes = [] + relationships_dict = {} + + # Convert nodes + for node in payload['nodes']: + nodes.append({ + "cypher_representation": f"(:{node['name']})", + "label": node['name'], + "indexes": node.get('indexes', []), + "constraints": node.get('constraints', []) + }) + + # Convert relationships + for relationship in payload['relationships']: + start_node = relationship[0]['name'] + relationship_type = relationship[1] + end_node = relationship[2]['name'] + path = f"(:{start_node})-[:{relationship_type}]->(:{end_node})" + + if relationship_type not in relationships_dict: + relationships_dict[relationship_type] = { + "cypher_representation": f"[:{relationship_type}]", + "type": relationship_type, + "paths": [] + } + + relationships_dict[relationship_type]["paths"].append({ + "path": path, + "detail": "" + }) + # Convert relationships_dict to a list + relationships = list(relationships_dict.values()) + + finalised_graph_structure = { + "nodes": nodes, + "relationships": relationships + } + json.dumps(finalised_graph_structure) + + await enrich_relationships(llm, model, finalised_graph_structure) 
+ await enrich_nodes(llm, model, finalised_graph_structure) + await enriched_rel_properties(llm, model, finalised_graph_structure) + await enrich_nodes_properties(llm, model, finalised_graph_structure) + + return finalised_graph_structure + +async def enrich_relationships(llm, model, finalised_graph_structure): + relationships = finalised_graph_structure['relationships'] + enriched_relationships_list = [] + + for relationship in relationships: + enriched_relationship = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationship), + return_json=True) + enriched_relationships_list.append(json.loads(enriched_relationship)) + + finalised_graph_structure['relationships'] = enriched_relationships_list + logger.debug(f"finalised graph structure with enriched relationships: {finalised_graph_structure}") + +async def enrich_nodes(llm, model, finalised_graph_structure): + neo4j_data = finalised_graph_structure['nodes'] + print(f"neo4j data: {neo4j_data}") + enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data), return_json=True) + enriched_nodes = json.loads(enriched_nodes) + json.dumps(enriched_nodes) + finalised_graph_structure['nodes'] = enriched_nodes + logger.debug(f"finalised graph structure: {finalised_graph_structure}") + print(f"finalised_graph_structure with nodes: {finalised_graph_structure}") + +async def enriched_rel_properties(llm, model, finalised_graph_structure): + properties_result = execute_query(relationship_property_query) + rel_properties_neo4j = properties_result[0] + cleaned_rel_properties = [] - neo4j_nodes_understanding_prompt = engine.load_prompt( - "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt - ) + for rel_property in rel_properties_neo4j['relProperties']: + cleaned_properties = [prop for prop in rel_property['properties'] if prop['name'] is not None] + if cleaned_properties: + rel_property['properties'] = cleaned_properties + cleaned_rel_properties.append(rel_property) - neo4j_relationship_property_prompt = engine.load_prompt( - "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt - ) + rel_properties_neo4j = {'relProperties': cleaned_rel_properties} + json.dumps(rel_properties_neo4j) - neo4j_node_property_prompt = engine.load_prompt( - "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt - ) + enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j), + return_json=True) + enriched_rel_properties = json.loads(enriched_rel_properties) - # Fetch and enrich relationships - relationship_result = execute_query(relationship_query) - relationships_neo4j = relationship_result[0] - enriched_relationships = call_model(neo4j_relationships_understanding_prompt, str(relationships_neo4j)) - enriched_relationships = json.dumps(enriched_relationships) - enriched_relationships = json.loads(enriched_relationships) - finalised_graph_structure["relationships"] = enriched_relationships if enriched_relationships else {} - logger.debug(f"Finalised graph structure with enriched relationships: {finalised_graph_structure}") - - # Fetch and enrich nodes - nodes_neo4j_result = execute_query(node_query) - nodes_neo4j = nodes_neo4j_result[0] - enriched_nodes = call_model(neo4j_nodes_understanding_prompt, str(nodes_neo4j)) - enriched_nodes = ast.literal_eval(enriched_nodes) - finalised_graph_structure["nodes"]["labels"] = enriched_nodes["nodes"] - logger.debug(f"Finalised graph structure with enriched nodes: 
{finalised_graph_structure}") - - # Fetch and enrich relationship properties - properties_result = execute_query(relationship_property_query) - rel_properties_neo4j = properties_result[0] - enriched_rel_properties = call_model(neo4j_relationship_property_prompt, str(rel_properties_neo4j)) - enriched_rel_properties = ast.literal_eval(enriched_rel_properties) - finalised_graph_structure["properties"]["relationship_properties"] = enriched_rel_properties["relProperties"] - logger.debug(f"Finalised graph structure with enriched relationship properties: {finalised_graph_structure}") + # Merge properties + for new_rel in enriched_rel_properties["relProperties"]: + relationship_type = new_rel["relType"] + properties_to_add = new_rel["property"] + + for rel in finalised_graph_structure["relationships"]: + if rel["cypher_representation"] == relationship_type: + if "properties" not in rel: + rel["property"] = [] + rel["property"] = properties_to_add - # Fetch and enrich node properties + logger.debug(f"finalised graph structure with enriched properties: {finalised_graph_structure}") + +async def enrich_nodes_properties(llm, model, finalised_graph_structure): node_properties_neo4j_result = execute_query(node_property_query) node_properties_neo4j = node_properties_neo4j_result[0] - enriched_node_properties = call_model(neo4j_node_property_prompt, str(node_properties_neo4j)) - enriched_node_properties = ast.literal_eval(enriched_node_properties) - finalised_graph_structure["properties"]["node_properties"] = enriched_node_properties["nodeProperties"] - logger.debug(f"Finalised graph structure with enriched node properties: {finalised_graph_structure}") - - graph_schema = json.dumps(finalised_graph_structure, separators=(",", ":")) - return graph_schema + filtered_payload = { + 'nodeProperties': [ + node for node in node_properties_neo4j['nodeProperties'] + if all(prop['data_type'] is not None and prop['name'] is not None for prop in node['properties']) + ] + } + enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload), + return_json=True) + enriched_node_properties = json.loads(enriched_node_properties) + + for new_node in enriched_node_properties["nodeProperties"]: + label = new_node["label"] + properties_to_add = new_node["properties"] + + for node in finalised_graph_structure["nodes"]: + if node["label"] == label: + if "properties" not in node: + node["properties"] = [] + node["properties"] = properties_to_add + logger.debug(f"finalised graph structure with enriched nodes: {finalised_graph_structure}") diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py index 2edaed606..1839b03ee 100644 --- a/backend/src/utils/web_utils.py +++ b/backend/src/utils/web_utils.py @@ -13,7 +13,7 @@ engine = PromptEngine() -def search_urls(search_query, num_results=10) -> str: +async def search_urls(search_query, num_results=10) -> str: logger.info(f"Searching the web for: {search_query}") urls = [] try: @@ -43,8 +43,8 @@ async def scrape_content(url, limit=100000) -> str: async with aiohttp.request("GET", url) as response: response.raise_for_status() soup = BeautifulSoup(await response.text(), "html.parser") - paragraphs = soup.find_all("p") - content = " ".join([para.get_text() for para in paragraphs]) + paragraphs_and_tables = soup.find_all(["p", "table", "h1", "h2", "h3", "h4", "h5", "h6"]) + content = "\n".join([tag.get_text() for tag in paragraphs_and_tables]) return json.dumps( { "status": "success", @@ -62,6 +62,47 @@ async def scrape_content(url, 
limit=100000) -> str: } ) +async def create_search_term(search_query, llm, model) -> str: + try: + summariser_prompt = engine.load_prompt("create-search-term", question=search_query) + response = await llm.chat(model, summariser_prompt, "", return_json=True) + return json.dumps( + { + "status": "success", + "response": response, + "error": None, + } + ) + except Exception as e: + logger.error(f"Error during create search term: {e}") + return json.dumps( + { + "status": "error", + "response": None, + "error": str(e), + } + ) + +async def answer_user_ques(search_query, llm, model) -> str: + try: + summariser_prompt = engine.load_prompt("answer-user-ques", question=search_query) + response = await llm.chat(model, summariser_prompt, "", return_json=True) + return json.dumps( + { + "status": "success", + "response": response, + "error": None, + } + ) + except Exception as e: + logger.error(f"Error during create search term: {e}") + return json.dumps( + { + "status": "error", + "response": None, + "error": str(e), + } + ) async def summarise_content(search_query, contents, llm, model) -> str: try: @@ -96,7 +137,51 @@ async def summarise_pdf_content(contents, llm, model) -> str: } ) except Exception as e: - logger.error(f"Error during summarisation: {e}") + logger.error(f"Error during summarisation of PDF: {e}") + return json.dumps( + { + "status": "error", + "response": None, + "error": str(e), + } + ) + +async def perform_math_operation_util(math_query, llm, model) -> str: + try: + math_prompt = engine.load_prompt("math-solver", query=math_query) + response = await llm.chat(model, math_prompt, "", return_json=True) + logger.info(f"Math operation response: {response}") + return json.dumps( + { + "status": "success", + "response": response, # math result + "error": None, + } + ) + except Exception as e: + logger.error(f"Error during math operation: {e}") + return json.dumps( + { + "status": "error", + "response": None, + "error": str(e), + } + ) + + +async def find_info(content, question, llm, model) -> str: + try: + find_info_prompt = engine.load_prompt("find-info", question=question, content=content) + response = await llm.chat(model, find_info_prompt, "", return_json=True) + return json.dumps( + { + "status": "success", + "response": response, + "error": None, + } + ) + except Exception as e: + logger.error(f"Error during finding info operation: {e}") return json.dumps( { "status": "error", diff --git a/backend/src/websockets/confirmations_manager.py b/backend/src/websockets/confirmations_manager.py new file mode 100644 index 000000000..0cf9fbcb7 --- /dev/null +++ b/backend/src/websockets/confirmations_manager.py @@ -0,0 +1,43 @@ +import logging +from typing import Dict +import uuid +from string import Template + +logger = logging.getLogger(__name__) + + +class ConfirmationsManager: + _open_confirmations: Dict[uuid.UUID, bool | None] = {} + _ERROR_MESSAGE = Template(" Confirmation with id '$confirmation_id' not found") + + def add_confirmation(self, confirmation_id: uuid.UUID): + self._open_confirmations[confirmation_id] = None + logger.info(f"Confirmation Added: {self._open_confirmations}") + + def get_confirmation_state(self, confirmation_id: uuid.UUID) -> bool | None: + if confirmation_id in self._open_confirmations: + return self._open_confirmations[confirmation_id] + else: + raise Exception( + "Cannot get confirmation." 
+ self._ERROR_MESSAGE.substitute(confirmation_id=confirmation_id) + ) + + def update_confirmation(self, confirmation_id: uuid.UUID, value: bool): + if confirmation_id in self._open_confirmations: + self._open_confirmations[confirmation_id] = value + else: + raise Exception( + "Cannot update confirmation." + self._ERROR_MESSAGE.substitute(confirmation_id=confirmation_id) + ) + + def delete_confirmation(self, confirmation_id: uuid.UUID): + if confirmation_id in self._open_confirmations: + del self._open_confirmations[confirmation_id] + logger.info(f"Confirmation Deleted: {self._open_confirmations}") + else: + raise Exception( + "Cannot delete confirmation." + self._ERROR_MESSAGE.substitute(confirmation_id=confirmation_id) + ) + + +confirmations_manager = ConfirmationsManager() diff --git a/backend/src/websockets/connection_manager.py b/backend/src/websockets/connection_manager.py index aac7a0b0a..6ca5e6b13 100644 --- a/backend/src/websockets/connection_manager.py +++ b/backend/src/websockets/connection_manager.py @@ -1,4 +1,3 @@ -import json import logging from typing import Any, Dict, List from fastapi import WebSocket @@ -9,6 +8,7 @@ logger = logging.getLogger(__name__) + def parse_message(message: Dict[str, Any]) -> Message: data = message.get("data") or None return Message(type=message["type"], data=data) @@ -46,11 +46,12 @@ async def handle_message(self, ws: WebSocket, message: Message): async def broadcast(self, message: Message): for ws in self.websockets: if ws.application_state == WebSocketState.CONNECTED: - await ws.send_json(json.dumps({"type": message.type.value, "data": message.data})) + await ws.send_json({"type": message.type.value, "data": message.data}) async def send_chart(self, data: Dict[str, Any]): for ws in self.websockets: if ws.application_state == WebSocketState.CONNECTED: await ws.send_json(data) + connection_manager = ConnectionManager() diff --git a/backend/src/websockets/message_handlers.py b/backend/src/websockets/message_handlers.py index 968c16502..5d2d3faab 100644 --- a/backend/src/websockets/message_handlers.py +++ b/backend/src/websockets/message_handlers.py @@ -1,9 +1,11 @@ import asyncio import json import logging +from uuid import UUID from fastapi import WebSocket from typing import Callable from .types import Handlers, MessageTypes +from src.websockets.confirmations_manager import confirmations_manager logger = logging.getLogger(__name__) @@ -38,4 +40,30 @@ def on_chat(websocket: WebSocket, disconnect: Callable, data: str | None): logger.info(f"Chat message: {data}") -handlers: Handlers = {MessageTypes.PING: create_on_ping(), MessageTypes.CHAT: on_chat} +def on_confirmation(websocket: WebSocket, disconnect: Callable, data: str | None): + if data is None: + logger.warning("Confirmation response did not include data") + return + if ":" not in data: + logger.warning("Seperator (':') not present in confirmation") + return + sections = data.split(":") + try: + id = UUID(sections[0]) + except ValueError: + logger.warning("Received invalid id") + return + if sections[1] != "y" and sections[1] != "n": + logger.warning("Received invalid value") + return + try: + confirmations_manager.update_confirmation(id, sections[1] == "y") + except Exception as e: + logger.warning(f"Could not update confirmation: '{e}'") + + +handlers: Handlers = { + MessageTypes.PING: create_on_ping(), + MessageTypes.CHAT: on_chat, + MessageTypes.CONFIRMATION: on_confirmation, +} diff --git a/backend/src/websockets/types.py b/backend/src/websockets/types.py index 182b8d3e2..20d4bd83c 
100644 --- a/backend/src/websockets/types.py +++ b/backend/src/websockets/types.py @@ -9,8 +9,9 @@ class MessageTypes(Enum): PING = "ping" PONG = "pong" CHAT = "chat" - LOG = "log" + LOG = "log" IMAGE = "image" + CONFIRMATION = "confirmation" @dataclass diff --git a/backend/src/websockets/user_confirmer.py b/backend/src/websockets/user_confirmer.py new file mode 100644 index 000000000..fc59ead24 --- /dev/null +++ b/backend/src/websockets/user_confirmer.py @@ -0,0 +1,45 @@ +import asyncio +import logging +import uuid +from src.websockets.types import Message, MessageTypes +from .connection_manager import connection_manager +from src.websockets.confirmations_manager import ConfirmationsManager + +logger = logging.getLogger(__name__) + + +class UserConfirmer: + _POLL_RATE_SECONDS = 0.5 + _TIMEOUT_SECONDS = 60.0 + _CONFIRMATIONS_MANAGER: ConfirmationsManager + + def __init__(self, manager: ConfirmationsManager): + self.confirmations_manager = manager + + async def confirm(self, msg: str) -> bool: + id = uuid.uuid4() + self.confirmations_manager.add_confirmation(id) + await self._send_confirmation(id, msg) + try: + async with asyncio.timeout(self._TIMEOUT_SECONDS): + return await self._check_confirmed(id) + except TimeoutError: + logger.warning(f"Confirmation with id {id} timed out.") + self.confirmations_manager.delete_confirmation(id) + return False + + async def _check_confirmed(self, id: uuid.UUID) -> bool: + while True: + try: + state = self.confirmations_manager.get_confirmation_state(id) + if isinstance(state, bool): + self.confirmations_manager.delete_confirmation(id) + return state + except Exception: + return False + await asyncio.sleep(self._POLL_RATE_SECONDS) + + async def _send_confirmation(self, id: uuid.UUID, msg: str): + data = f"{str(id)}:{msg}" + message = Message(MessageTypes.CONFIRMATION, data) + await connection_manager.broadcast(message) diff --git a/backend/tests/agents/chart_generator_agent_test.py b/backend/tests/agents/chart_generator_agent_test.py index 8bfb7629e..b3d404336 100644 --- a/backend/tests/agents/chart_generator_agent_test.py +++ b/backend/tests/agents/chart_generator_agent_test.py @@ -1,61 +1,143 @@ from io import BytesIO -import unittest from unittest.mock import patch, AsyncMock, MagicMock import pytest from src.agents.chart_generator_agent import generate_chart +import base64 +import matplotlib.pyplot as plt +from PIL import Image +import json +from src.agents.chart_generator_agent import sanitise_script +@pytest.mark.asyncio +@patch("src.agents.chart_generator_agent.engine.load_prompt") +@patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock) +async def test_generate_code_success(mock_sanitise_script, mock_load_prompt): + llm = AsyncMock() + model = "mock_model" -class TestGenerateChartAgent(unittest.TestCase): - def setUp(self): - self.llm = AsyncMock() - self.model = "mock_model" - self.details_to_generate_chart_code = "details to generate chart code" - self.generate_chart_code_prompt = "generate chart code prompt" + mock_load_prompt.side_effect = [ + "details to create chart code prompt", + "generate chart code prompt" + ] - @pytest.mark.asyncio - @patch("src.agents.chart_generator_agent.engine.load_prompt") - @patch("src.agents.chart_generator_agent.sanitise_script") - async def test_generate_code_success(self, mock_sanitise_script, mock_load_prompt): - mock_load_prompt.side_effect = [self.details_to_generate_chart_code, self.generate_chart_code_prompt] - self.llm.chat.return_value = "generated code" - 
mock_sanitise_script.return_value = """ + llm.chat.return_value = "generated code" + return_string = mock_sanitise_script.return_value = """ import matplotlib.pyplot as plt fig = plt.figure() plt.plot([1, 2, 3], [4, 5, 6]) - """ + plt.switch_backend('Agg') + + def mock_exec_side_effect(script, globals=None, locals=None): + if script == return_string: + fig = plt.figure() + plt.plot([1, 2, 3], [4, 5, 6]) + if locals is None: + locals = {} + locals['fig'] = fig + + with patch("builtins.exec", side_effect=mock_exec_side_effect): + result = await generate_chart("question_intent", "data_provided", "question_params", llm, model) - with patch("matplotlib.pyplot.figure") as mock_fig: - mock_fig_instance = MagicMock() - mock_fig.return_value = mock_fig_instance - result = await generate_chart("question_intent", "data_provided", "question_params", self.llm, self.model) - buf = BytesIO() - mock_fig_instance.savefig.assert_called_once_with(buf, format="png") + response = json.loads(result) - self.llm.chat.assert_called_once_with( - self.model, self.generate_chart_code_prompt, self.details_to_generate_chart_code + image_data = response["content"] + decoded_image = base64.b64decode(image_data) + + image = Image.open(BytesIO(decoded_image)) + image.verify() + + llm.chat.assert_called_once_with( + model, + "generate chart code prompt", + "details to create chart code prompt" ) mock_sanitise_script.assert_called_once_with("generated code") - self.assertIsInstance(result, str) - @pytest.mark.asyncio - @patch("src.agents.chart_generator_agent.engine.load_prompt") - @patch("src.agents.chart_generator_agent.sanitise_script") - async def test_generate_code_no_figure(self, mock_sanitise_script, mock_load_prompt): - mock_load_prompt.side_effect = [self.details_to_generate_chart_code, self.generate_chart_code_prompt] - self.llm.chat.return_value = "generated code" - mock_sanitise_script.return_value = """ +@pytest.mark.asyncio +@patch("src.agents.chart_generator_agent.engine.load_prompt") +@patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock) +async def test_generate_code_no_figure(mock_sanitise_script, mock_load_prompt): + llm = AsyncMock() + model = "mock_model" -import matplotlib.pyplot as plt -# No figure is created + mock_load_prompt.side_effect = [ + "details to create chart code prompt", + "generate chart code prompt" + ] + llm.chat.return_value = "generated code" + + return_string = mock_sanitise_script.return_value = """ +import matplotlib.pyplot as plt +# No fig creation """ - with self.assertRaises(ValueError) as context: - await generate_chart("question_intent", "data_provided", "question_params", self.llm, self.model) - self.assertEqual(str(context.exception), "The generated code did not produce a figure named 'fig'.") + plt.switch_backend('Agg') + + def mock_exec_side_effect(script, globals=None, locals=None): + if script == return_string: + if locals is None: + locals = {} + + with patch("builtins.exec", side_effect=mock_exec_side_effect): + with pytest.raises(ValueError, match="The generated code did not produce a figure named 'fig'."): + await generate_chart("question_intent", "data_provided", "question_params", llm, model) + + llm.chat.assert_called_once_with( + model, + "generate chart code prompt", + "details to create chart code prompt" + ) + + mock_sanitise_script.assert_called_once_with("generated code") +@pytest.mark.parametrize( + "input_script, expected_output", + [ -if __name__ == "__main__": - unittest.main() + ( + """```python +import matplotlib.pyplot as 
plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6]) +```""", + """import matplotlib.pyplot as plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6])""" + ), + ( + """```python +import matplotlib.pyplot as plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6])""", + """import matplotlib.pyplot as plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6])""" + ), + ( + """import matplotlib.pyplot as plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6]) +```""", + """import matplotlib.pyplot as plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6])""" + ), + ( + """import matplotlib.pyplot as plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6])""", + """import matplotlib.pyplot as plt +fig = plt.figure() +plt.plot([1, 2, 3], [4, 5, 6])""" + ), + ( + "", + "" + ) + ] +) +def test_sanitise_script(input_script, expected_output): + assert sanitise_script(input_script) == expected_output diff --git a/backend/tests/agents/datastore_agent_test.py b/backend/tests/agents/datastore_agent_test.py new file mode 100644 index 000000000..c6de422b7 --- /dev/null +++ b/backend/tests/agents/datastore_agent_test.py @@ -0,0 +1,90 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from src.agents.datastore_agent import generate_cypher_query_core + +@pytest.mark.asyncio +@patch("src.agents.datastore_agent.get_semantic_layer", new_callable=AsyncMock) +@patch("src.agents.datastore_agent.execute_query", new_callable=MagicMock) +@patch("src.agents.datastore_agent.publish_log_info", new_callable=AsyncMock) +@patch("src.agents.datastore_agent.engine.load_prompt", autospec=True) +async def test_generate_query_success(mock_load_prompt, mock_publish_log_info, + mock_execute_query, mock_get_semantic_layer): + llm = AsyncMock() + model = "mock_model" + + mock_load_prompt.side_effect = [ + "details to create cypher query prompt", + "generate cypher query prompt" + ] + + llm.chat.return_value = '{"query": "MATCH (n) RETURN n"}' + + mock_get_semantic_layer.return_value = {"nodes": [], "edges": []} + + mock_execute_query.return_value = "Mocked response from the database" + + question_intent = "Find all nodes" + operation = "MATCH" + question_params = "n" + aggregation = "none" + sort_order = "none" + timeframe = "2024" + model = "gpt-4" + + result = await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order, + timeframe, llm, model) + + assert result == '{\n "content": "Mocked response from the database",\n "ignore_validation": "false"\n}' + mock_load_prompt.assert_called() + llm.chat.assert_called_once_with( + model, + "generate cypher query prompt", + "details to create cypher query prompt", + return_json=True + ) + mock_execute_query.assert_called_once_with("MATCH (n) RETURN n") + mock_publish_log_info.assert_called() + +@pytest.mark.asyncio +@patch("src.agents.datastore_agent.get_semantic_layer", new_callable=AsyncMock) +@patch("src.agents.datastore_agent.execute_query", new_callable=MagicMock) +@patch("src.agents.datastore_agent.publish_log_info", new_callable=AsyncMock) +@patch("src.agents.datastore_agent.engine.load_prompt", autospec=True) +async def test_generate_query_failure(mock_load_prompt, mock_publish_log_info, + mock_execute_query, mock_get_semantic_layer): + llm = AsyncMock() + model = "mock_model" + + mock_load_prompt.side_effect = [ + "details to create cypher query prompt", + "generate cypher query prompt" + ] + + llm.chat.side_effect = Exception("LLM chat failed") + + mock_get_semantic_layer.return_value = {"nodes": [], "edges": []} + 
+ question_intent = "Find all nodes" + operation = "MATCH" + question_params = "n" + aggregation = "none" + sort_order = "none" + timeframe = "2024" + model = "gpt-4" + + with pytest.raises(Exception, match="LLM chat failed"): + await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order, + timeframe, llm, model) + + mock_load_prompt.assert_called() + llm.chat.assert_called_once_with( + model, + "generate cypher query prompt", + "details to create cypher query prompt", + return_json=True + ) + mock_publish_log_info.assert_not_called() + mock_execute_query.assert_not_called() + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/backend/tests/agents/file_agent_test.py b/backend/tests/agents/file_agent_test.py new file mode 100644 index 000000000..5d3f4e1dd --- /dev/null +++ b/backend/tests/agents/file_agent_test.py @@ -0,0 +1,53 @@ +import pytest +from unittest.mock import patch, mock_open +import json +import os +from src.agents.file_agent import read_file_core, write_or_update_file_core, create_response + +# Mocking config for the test +@pytest.fixture(autouse=True) +def mock_config(monkeypatch): + monkeypatch.setattr('src.agents.file_agent.config.files_directory', 'files') + +@pytest.mark.asyncio +@patch("builtins.open", new_callable=mock_open, read_data="Example file content.") +async def test_read_file_core_success(mock_file): + file_path = "example.txt" + result = await read_file_core(file_path) + expected_response = create_response("Example file content.") + assert json.loads(result) == json.loads(expected_response) + expected_full_path = os.path.normpath("/app/files/example.txt") + mock_file.assert_called_once_with(expected_full_path, 'r') + +@pytest.mark.asyncio +@patch("builtins.open", side_effect=FileNotFoundError) +async def test_read_file_core_file_not_found(mock_file): + file_path = "missing_file.txt" + result = await read_file_core(file_path) + expected_response = create_response(f"File {file_path} not found.", "error") + assert json.loads(result) == json.loads(expected_response) + expected_full_path = os.path.normpath("/app/files/missing_file.txt") + mock_file.assert_called_once_with(expected_full_path, 'r') + +@pytest.mark.asyncio +@patch("builtins.open", new_callable=mock_open) +async def test_write_file_core_success(mock_file): + file_path = "example_write.txt" + content = "This is test content to write." + result = await write_or_update_file_core(file_path, content, 'no') + expected_response = create_response(f"Content written to file {file_path}.") + assert json.loads(result) == json.loads(expected_response) + expected_full_path = os.path.normpath("/app/files/example_write.txt") + mock_file.assert_called_once_with(expected_full_path, 'w') + mock_file().write.assert_called_once_with(content) + +@pytest.mark.asyncio +@patch("builtins.open", side_effect=Exception("Unexpected error")) +async def test_write_file_core_error(mock_file): + file_path = "error_file.txt" + content = "Content with error." 
+ result = await write_or_update_file_core(file_path, content, 'no') + expected_response = create_response(f"Error writing to file: {file_path}", "error") + assert json.loads(result) == json.loads(expected_response) + expected_full_path = os.path.normpath("/app/files/error_file.txt") + mock_file.assert_called_once_with(expected_full_path, 'w') diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py index b50dc36ed..b863aaa8a 100644 --- a/backend/tests/agents/web_agent_test.py +++ b/backend/tests/agents/web_agent_test.py @@ -1,76 +1,83 @@ -import unittest -from unittest.mock import AsyncMock, patch - import pytest +from unittest.mock import patch, AsyncMock +import json from src.agents.web_agent import web_general_search_core +@pytest.mark.asyncio +@patch("src.agents.web_agent.answer_user_ques", new_callable=AsyncMock) +@patch("src.agents.web_agent.create_search_term", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) +@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) +async def test_web_general_search_core( + mock_is_valid_answer, + mock_perform_summarization, + mock_perform_scrape, + mock_perform_search, + mock_create_search_term, + mock_answer_user_ques +): + llm = AsyncMock() + model = "mock_model" -class TestWebAgentCore(unittest.TestCase): - def setUp(self): - self.llm = AsyncMock() - self.model = "mock_model" - - @patch("src.agents.web_agent.perform_search") - @patch("src.agents.web_agent.perform_scrape") - @patch("src.agents.web_agent.perform_summarization") - @patch("src.agents.web_agent.is_valid_answer") - @pytest.mark.asyncio - async def test_web_general_search_core( - self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search - ): - mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]} - mock_perform_scrape.return_value = "Example scraped content." - mock_perform_summarization.return_value = "Example summary." 
- mock_is_valid_answer.return_value = True - - result = await web_general_search_core("example query", self.llm, self.model) - self.assertEqual(result, "Example summary.") - mock_perform_search.assert_called_once_with("example query", num_results=15) - mock_perform_scrape.assert_called_once_with("http://example.com") - mock_perform_summarization.assert_called_once_with( - "example query", "Example scraped content.", self.llm, self.model - ) - mock_is_valid_answer.assert_called_once_with("Example summary.", "example query") + # Mocking answer_user_ques to return a valid answer + mock_answer_user_ques.return_value = json.dumps({ + "status": "success", + "response": json.dumps({"is_valid": True, "answer": "Example summary."}) + }) - @patch("src.agents.web_agent.perform_search") - @patch("src.agents.web_agent.perform_scrape") - @patch("src.agents.web_agent.perform_summarization") - @patch("src.agents.web_agent.is_valid_answer") - @pytest.mark.asyncio - async def test_web_general_search_core_no_results( - self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search - ): - mock_perform_search.return_value = {"status": "error", "urls": []} + result = await web_general_search_core("example query", llm, model) + expected_response = { + "content": "Example summary.", + "ignore_validation": "false" + } + assert json.loads(result) == expected_response - result = await web_general_search_core("example query", self.llm, self.model) - self.assertEqual(result, "No relevant information found on the internet for the given query.") - mock_perform_search.assert_called_once_with("example query", num_results=15) - mock_perform_scrape.assert_not_called() - mock_perform_summarization.assert_not_called() - mock_is_valid_answer.assert_not_called() - @patch("src.agents.web_agent.perform_search") - @patch("src.agents.web_agent.perform_scrape") - @patch("src.agents.web_agent.perform_summarization") - @patch("src.agents.web_agent.is_valid_answer") - @pytest.mark.asyncio - async def test_web_general_search_core_invalid_summary( - self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search - ): - mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]} - mock_perform_scrape.return_value = "Example scraped content." - mock_perform_summarization.return_value = "Example invalid summary." 
- mock_is_valid_answer.return_value = False +@pytest.mark.asyncio +@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) +@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) +async def test_web_general_search_core_no_results( + mock_is_valid_answer, + mock_perform_summarization, + mock_perform_scrape, + mock_perform_search, +): + llm = AsyncMock() + model = "mock_model" + mock_perform_search.return_value = {"status": "error", "urls": []} + result = await web_general_search_core("example query", llm, model) - result = await web_general_search_core("example query", self.llm, self.model) - self.assertEqual(result, "No relevant information found on the internet for the given query.") - mock_perform_search.assert_called_once_with("example query", num_results=15) - mock_perform_scrape.assert_called_once_with("http://example.com") - mock_perform_summarization.assert_called_once_with( - "example query", "Example scraped content.", self.llm, self.model - ) - mock_is_valid_answer.assert_called_once_with("Example invalid summary.", "example query") + expected_response = { + "content": "Error in finding the answer.", + "ignore_validation": "false" + } + assert json.loads(result) == expected_response +@pytest.mark.asyncio +@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) +@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) +async def test_web_general_search_core_invalid_summary( + mock_is_valid_answer, + mock_perform_summarization, + mock_perform_scrape, + mock_perform_search +): + llm = AsyncMock() + model = "mock_model" + mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]} + mock_perform_scrape.return_value = "Example scraped content." 
+ mock_perform_summarization.return_value = json.dumps({"summary": "Example invalid summary."}) + mock_is_valid_answer.return_value = False + result = await web_general_search_core("example query", llm, model) + expected_response = { + "content": "Error in finding the answer.", + "ignore_validation": "false" + } + assert json.loads(result) == expected_response -if __name__ == "__main__": - unittest.main() diff --git a/backend/tests/api/message_handlers_test.py b/backend/tests/api/message_handlers_test.py deleted file mode 100644 index 42040c591..000000000 --- a/backend/tests/api/message_handlers_test.py +++ /dev/null @@ -1,26 +0,0 @@ -from unittest.mock import call -import pytest -from src.websockets.message_handlers import create_on_ping, pong - - -def test_on_ping_send_pong(mocker): - on_ping = create_on_ping() - mock_ws = mocker.Mock() - mock_disconnect = mocker.AsyncMock() - mocked_create_task = mocker.patch("asyncio.create_task") - - on_ping(mock_ws, mock_disconnect, None) - - first_call = mocked_create_task.call_args_list[0] - assert first_call == call(mock_ws.send_json(pong)) - - -@pytest.mark.asyncio -async def test_on_ping_no_disconnect(mocker): - on_ping = create_on_ping() - mock_ws = mocker.AsyncMock() - mock_disconnect = mocker.AsyncMock() - - on_ping(mock_ws, mock_disconnect, None) - - mock_disconnect.assert_not_awaited() diff --git a/backend/tests/llm/mistral_test.py b/backend/tests/llm/mistral_test.py index 86f87f7b2..f92a5ee86 100644 --- a/backend/tests/llm/mistral_test.py +++ b/backend/tests/llm/mistral_test.py @@ -1,10 +1,8 @@ +import logging from typing import cast -from unittest.mock import MagicMock -from mistralai.async_client import MistralAsyncClient -from mistralai.models.chat_completion import ChatCompletionResponse -from mistralai.models.chat_completion import ChatCompletionResponseChoice -from mistralai.models.chat_completion import ChatMessage -from mistralai.models.common import UsageInfo +from unittest.mock import AsyncMock, MagicMock +from mistralai import UNSET, AssistantMessage, Mistral as MistralApi, SystemMessage, UserMessage +from mistralai.models import ChatCompletionResponse, ChatCompletionChoice, UsageInfo import pytest from src.llm import get_llm, Mistral from src.utils import Config @@ -17,34 +15,23 @@ mistral = cast(Mistral, get_llm("mistral")) -async def create_mock_chat_response(content, tool_calls=None): +def create_mock_chat_response(content, tool_calls=None): mock_usage = UsageInfo(prompt_tokens=1, total_tokens=2, completion_tokens=3) - mock_message = ChatMessage(role="system", content=content, tool_calls=tool_calls) - mock_choice = ChatCompletionResponseChoice(index=0, message=mock_message, finish_reason=None) + mock_message = AssistantMessage(content=content, tool_calls=tool_calls) + mock_choice = ChatCompletionChoice(index=0, message=mock_message, finish_reason="stop") return ChatCompletionResponse( id="id", object="object", created=123, model="model", choices=[mock_choice], usage=mock_usage ) -mock_client = MagicMock(spec=MistralAsyncClient) +mock_client = AsyncMock(spec=MistralApi) mock_config = MagicMock(spec=Config) @pytest.mark.asyncio async def test_chat_content_string_returns_string(mocker): - mistral.client = mocker.MagicMock(return_value=mock_client) - mistral.client.chat.return_value = create_mock_chat_response(content_response) - - response = await mistral.chat(mock_model, system_prompt, user_prompt) - - assert response == content_response - - -@pytest.mark.asyncio -async def test_chat_content_list_returns_string(mocker): - 
content_list = ["Hello", "there"] - mistral.client = mocker.MagicMock(return_value=mock_client) - mistral.client.chat.return_value = create_mock_chat_response(content_list) + mistral.client = mocker.AsyncMock(return_value=mock_client) + mistral.client.chat.complete_async.return_value = create_mock_chat_response(content_response) response = await mistral.chat(mock_model, system_prompt, user_prompt) @@ -58,9 +45,94 @@ async def test_chat_calls_client_chat(mocker): await mistral.chat(mock_model, system_prompt, user_prompt) expected_messages = [ - ChatMessage(role="system", content=system_prompt), - ChatMessage(role="user", content=user_prompt), + SystemMessage(content=system_prompt), + UserMessage(content=user_prompt), ] - mistral.client.chat.assert_called_once_with( + mistral.client.chat.complete_async.assert_awaited_once_with( messages=expected_messages, model=mock_model, temperature=0, response_format=None ) + + +@pytest.mark.asyncio +async def test_chat_response_none_logs_error(mocker, caplog): + mistral.client = mocker.AsyncMock(return_value=mock_client) + mistral.client.chat.complete_async.return_value = None + + response = await mistral.chat(mock_model, system_prompt, user_prompt) + + assert response == "An error occurred while processing the request." + assert ( + "src.llm.mistral", + logging.ERROR, + "Call to Mistral API failed: No valid response or choices received", + ) in caplog.record_tuples + + +@pytest.mark.asyncio +async def test_chat_response_choices_none_logs_error(mocker, caplog): + mistral.client = mocker.AsyncMock(return_value=mock_client) + chat_response = create_mock_chat_response(content_response) + chat_response.choices = None + mistral.client.chat.complete_async.return_value = chat_response + + response = await mistral.chat(mock_model, system_prompt, user_prompt) + + assert response == "An error occurred while processing the request." + assert ( + "src.llm.mistral", + logging.ERROR, + "Call to Mistral API failed: No valid response or choices received", + ) in caplog.record_tuples + + +@pytest.mark.asyncio +async def test_chat_response_choices_empty_logs_error(mocker, caplog): + mistral.client = mocker.AsyncMock(return_value=mock_client) + chat_response = create_mock_chat_response(content_response) + chat_response.choices = [] + mistral.client.chat.complete_async.return_value = chat_response + + response = await mistral.chat(mock_model, system_prompt, user_prompt) + + assert response == "An error occurred while processing the request." + assert ( + "src.llm.mistral", + logging.ERROR, + "Call to Mistral API failed: No valid response or choices received", + ) in caplog.record_tuples + + +@pytest.mark.asyncio +async def test_chat_response_choices_message_content_none_logs_error(mocker, caplog): + mistral.client = mocker.AsyncMock(return_value=mock_client) + chat_response = create_mock_chat_response(content_response) + assert chat_response.choices is not None + chat_response.choices[0].message.content = None + mistral.client.chat.complete_async.return_value = chat_response + + response = await mistral.chat(mock_model, system_prompt, user_prompt) + + assert response == "An error occurred while processing the request." 
+ assert ( + "src.llm.mistral", + logging.ERROR, + "Call to Mistral API failed: message content is None or Unset", + ) in caplog.record_tuples + + +@pytest.mark.asyncio +async def test_chat_response_choices_message_content_unset_logs_error(mocker, caplog): + mistral.client = mocker.AsyncMock(return_value=mock_client) + chat_response = create_mock_chat_response(content_response) + assert chat_response.choices is not None + chat_response.choices[0].message.content = UNSET + mistral.client.chat.complete_async.return_value = chat_response + + response = await mistral.chat(mock_model, system_prompt, user_prompt) + + assert response == "An error occurred while processing the request." + assert ( + "src.llm.mistral", + logging.ERROR, + "Call to Mistral API failed: message content is None or Unset", + ) in caplog.record_tuples diff --git a/backend/tests/prompts/prompting_test.py b/backend/tests/prompts/prompting_test.py index 66907e0f1..027574cab 100644 --- a/backend/tests/prompts/prompting_test.py +++ b/backend/tests/prompts/prompting_test.py @@ -58,11 +58,12 @@ def test_load_best_next_step_template(): AGENT LIST: -If the list of agents does not contain something suitable, you should say the agent is 'none'. ie. If question is 'general knowledge', 'personal' or a 'greeting'. +If the list of agents does not contain something suitable, you should say the agent is 'WebAgent'. ie. If question is 'general knowledge', 'personal' or a 'greeting'. ## Determine the next best step Your task is to pick one of the mentioned agents above to complete the task. If the same agent_name and task are repeated more than twice in the history, you must not pick that agent_name. +If mathematical processing (e.g., rounding or calculations) is needed, choose the MathsAgent. If file operations are needed, choose the FileAgent. Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications. @@ -103,11 +104,12 @@ def test_load_best_next_step_with_history_template(): AGENT LIST: -If the list of agents does not contain something suitable, you should say the agent is 'none'. ie. If question is 'general knowledge', 'personal' or a 'greeting'. +If the list of agents does not contain something suitable, you should say the agent is 'WebAgent'. ie. If question is 'general knowledge', 'personal' or a 'greeting'. ## Determine the next best step Your task is to pick one of the mentioned agents above to complete the task. If the same agent_name and task are repeated more than twice in the history, you must not pick that agent_name. +If mathematical processing (e.g., rounding or calculations) is needed, choose the MathsAgent. If file operations are needed, choose the FileAgent. Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications. @@ -136,11 +138,16 @@ def test_best_tool_template(): Pick 1 tool (no more than 1) from the list below to complete this task. Fit the correct parameters from the task to the tool arguments. +Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "£") and units of measurement (e.g., "million") if applicable. Parameters with required as False do not need to be fit. 
Add if appropriate, but do not hallucinate arguments for these parameters {"description": "mock desc", "name": "say hello world", "parameters": {"name": {"type": "string", "description": "name of user"}}} +Important: +If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "£") and rounded to the nearest million if specified. +If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "£15 million", "£5000"). + From the task you should be able to extract the parameters. If it is data driven, it should be turned into a cypher query If none of the tools are appropriate for the task, return the following tool @@ -161,7 +168,8 @@ def test_best_tool_template(): def test_tool_selection_format_template(): engine = PromptEngine() try: - expected_string = """Reply only in json with the following format: + expected_string = """Reply only in json with the following format, in the tool_parameters please include the currency and measuring scale used in the content provided.: + { \"tool_name\": \"the exact string name of the tool chosen\", diff --git a/backend/tests/supervisors/supervisor_test.py b/backend/tests/supervisors/supervisor_test.py index 31b5c42d5..31c0a4147 100644 --- a/backend/tests/supervisors/supervisor_test.py +++ b/backend/tests/supervisors/supervisor_test.py @@ -64,7 +64,7 @@ async def test_solve_task_first_attempt_solves(mocker): mock_answer_json = json.loads(mock_answer) # Ensure that the result is returned directly without validation - assert answer == (agent.name, mock_answer_json.get('content', '')) + assert answer == (agent.name, mock_answer_json.get('content', ''), "success") @pytest.mark.asyncio @@ -83,7 +83,7 @@ async def test_solve_task_ignore_validation(mocker): mock_answer_json = json.loads(mock_answer) # Ensure that the result is returned directly without validation - assert answer == (agent.name, mock_answer_json.get('content', '')) + assert answer == (agent.name, mock_answer_json.get('content', ''), "success") mock_is_valid_answer.assert_not_called() # Validation should not be called @pytest.mark.asyncio diff --git a/backend/tests/websockets/confirmations_manager_test.py b/backend/tests/websockets/confirmations_manager_test.py new file mode 100644 index 000000000..85151831b --- /dev/null +++ b/backend/tests/websockets/confirmations_manager_test.py @@ -0,0 +1,86 @@ +from uuid import uuid4 + +import pytest +from src.websockets.confirmations_manager import ConfirmationsManager + + +class TestConfirmationsManager: + def test_add_confirmation(self): + # Arrange + manager = ConfirmationsManager() + confirmation_id = uuid4() + + # Act + manager.add_confirmation(confirmation_id) + + # Assert + confirmation = manager.get_confirmation_state(confirmation_id) + assert confirmation is None + + def test_get_confirmation_state_not_found_id(self): + # Arrange + manager = ConfirmationsManager() + not_found_confirmation_id = uuid4() + + # Act + with pytest.raises(Exception) as e: + manager.get_confirmation_state(not_found_confirmation_id) + + # Assert + assert str(e.value) == f"Cannot get confirmation. 
Confirmation with id '{not_found_confirmation_id}' not found" + + @pytest.mark.parametrize("input_value", [True, False]) + def test_update_confirmation(self, input_value): + # Arrange + manager = ConfirmationsManager() + confirmation_id = uuid4() + manager.add_confirmation(confirmation_id) + + # Act + manager.update_confirmation(confirmation_id, input_value) + + # Assert + updated_value = manager.get_confirmation_state(confirmation_id) + assert updated_value == input_value + + def test_update_confirmation_not_found_id(self): + # Arrange + manager = ConfirmationsManager() + not_found_confirmation_id = uuid4() + + # Act + with pytest.raises(Exception) as e: + manager.update_confirmation(not_found_confirmation_id, True) + + # Assert + assert ( + str(e.value) == f"Cannot update confirmation. Confirmation with id '{not_found_confirmation_id}' not found" + ) + + def test_delete_confirmation(self): + # Arrange + manager = ConfirmationsManager() + confirmation_id = uuid4() + manager.add_confirmation(confirmation_id) + + # Act + manager.delete_confirmation(confirmation_id) + + # Assert + with pytest.raises(Exception) as e: + manager.get_confirmation_state(confirmation_id) + assert "Cannot get confirmation." in str(e.value) + + def test_delete_confirmation_not_found_id(self): + # Arrange + manager = ConfirmationsManager() + not_found_confirmation_id = uuid4() + + # Act + with pytest.raises(Exception) as e: + manager.delete_confirmation(not_found_confirmation_id) + + # Assert + assert ( + str(e.value) == f"Cannot delete confirmation. Confirmation with id '{not_found_confirmation_id}' not found" + ) diff --git a/backend/tests/api/connection_manager_test.py b/backend/tests/websockets/connection_manager_test.py similarity index 97% rename from backend/tests/api/connection_manager_test.py rename to backend/tests/websockets/connection_manager_test.py index 087941382..809742314 100644 --- a/backend/tests/api/connection_manager_test.py +++ b/backend/tests/websockets/connection_manager_test.py @@ -1,4 +1,3 @@ -import json from unittest.mock import patch import pytest from fastapi.websockets import WebSocketState @@ -101,6 +100,7 @@ async def test_disconnect_websocket_already_disconnected(connection_manager): mock_ws.close.assert_not_called() assert len(manager.websockets) == 0 + @pytest.mark.asyncio async def test_handle_message_handler_exists_for_message_type_handler_called(connection_manager, mocker): manager, mock_ws, _ = connection_manager @@ -113,7 +113,6 @@ async def test_handle_message_handler_exists_for_message_type_handler_called(con handler.assert_called() - @pytest.mark.asyncio async def test_handle_message_handler_does_not_exist_for_message_type_handler_called(connection_manager): manager, mock_ws, _ = connection_manager @@ -126,7 +125,6 @@ async def test_handle_message_handler_does_not_exist_for_message_type_handler_ca assert str(error.value) == "No handler for message type" - @pytest.mark.asyncio async def test_broadcast_given_message_broadcasted(connection_manager): manager, mock_ws, _ = connection_manager @@ -136,7 +134,7 @@ async def test_broadcast_given_message_broadcasted(connection_manager): await manager.broadcast(message) - mock_ws.send_json.assert_awaited_once_with(json.dumps({"type": message.type.value, "data": message.data})) + mock_ws.send_json.assert_awaited_once_with({"type": message.type.value, "data": message.data}) @pytest.mark.asyncio diff --git a/backend/tests/websockets/message_handlers_test.py b/backend/tests/websockets/message_handlers_test.py new file mode 100644 index 
000000000..baf13f478 --- /dev/null +++ b/backend/tests/websockets/message_handlers_test.py @@ -0,0 +1,134 @@ +import logging +from unittest.mock import Mock, call, patch +from uuid import uuid4 +import pytest +from src.websockets.message_handlers import create_on_ping, on_confirmation, pong + + +def test_on_ping_send_pong(mocker): + on_ping = create_on_ping() + mock_ws = mocker.Mock() + mock_disconnect = mocker.AsyncMock() + mocked_create_task = mocker.patch("asyncio.create_task") + + on_ping(mock_ws, mock_disconnect, None) + + first_call = mocked_create_task.call_args_list[0] + assert first_call == call(mock_ws.send_json(pong)) + + +@pytest.mark.asyncio +async def test_on_ping_no_disconnect(mocker): + on_ping = create_on_ping() + mock_ws = mocker.AsyncMock() + mock_disconnect = mocker.AsyncMock() + + on_ping(mock_ws, mock_disconnect, None) + + mock_disconnect.assert_not_awaited() + + +@pytest.mark.parametrize("input_value,expected_bool", [("y", True), ("n", False)]) +@patch("src.websockets.message_handlers.confirmations_manager") +def test_on_confirmation(confirmations_manager_mock, input_value, expected_bool): + # Arrange + confirmation_id = uuid4() + data = f"{confirmation_id}:{input_value}" + websocket_mock = Mock() + disconnect_mock = Mock() + + # Act + on_confirmation(websocket_mock, disconnect_mock, data) + + # Assert + confirmations_manager_mock.update_confirmation.assert_called_once_with(confirmation_id, expected_bool) + + +@patch("src.websockets.message_handlers.confirmations_manager") +def test_on_confirmation_data_is_none(confirmations_manager_mock, caplog): + # Arrange + websocket_mock = Mock() + disconnect_mock = Mock() + + # Act + on_confirmation(websocket_mock, disconnect_mock, None) + + # Assert + confirmations_manager_mock.update_confirmation.assert_not_called() + assert ( + "src.websockets.message_handlers", + logging.WARNING, + "Confirmation response did not include data", + ) in caplog.record_tuples + + +@patch("src.websockets.message_handlers.confirmations_manager") +def test_on_confirmation_seperator_not_present(confirmations_manager_mock, caplog): + # Arrange + websocket_mock = Mock() + disconnect_mock = Mock() + data = "abc" + + # Act + on_confirmation(websocket_mock, disconnect_mock, data) + + # Assert + confirmations_manager_mock.update_confirmation.assert_not_called() + assert ( + "src.websockets.message_handlers", + logging.WARNING, + "Seperator (':') not present in confirmation", + ) in caplog.record_tuples + + +@patch("src.websockets.message_handlers.confirmations_manager") +def test_on_confirmation_seperator_id_not_uuid(confirmations_manager_mock, caplog): + # Arrange + websocket_mock = Mock() + disconnect_mock = Mock() + data = "abc:y" + + # Act + on_confirmation(websocket_mock, disconnect_mock, data) + + # Assert + confirmations_manager_mock.update_confirmation.assert_not_called() + assert ("src.websockets.message_handlers", logging.WARNING, "Received invalid id") in caplog.record_tuples + + +@pytest.mark.parametrize("input_value", [(""), ("abc")]) +@patch("src.websockets.message_handlers.confirmations_manager") +def test_on_confirmation_value_not_valid(confirmations_manager_mock, caplog, input_value): + # Arrange + websocket_mock = Mock() + disconnect_mock = Mock() + confirmation_id = uuid4() + data = f"{confirmation_id}:{input_value}" + + # Act + on_confirmation(websocket_mock, disconnect_mock, data) + + # Assert + confirmations_manager_mock.update_confirmation.assert_not_called() + assert ("src.websockets.message_handlers", logging.WARNING, "Received 
invalid value") in caplog.record_tuples + + +@patch("src.websockets.message_handlers.confirmations_manager") +def test_on_confirmation_confirmation_manager_raises_exception(confirmations_manager_mock, caplog): + # Arrange + confirmation_id = uuid4() + data = f"{confirmation_id}:y" + websocket_mock = Mock() + disconnect_mock = Mock() + exception_message = "Test Exception Message" + confirmations_manager_mock.update_confirmation.side_effect = Exception(exception_message) + + # Act + on_confirmation(websocket_mock, disconnect_mock, data) + + # Assert + assert ( + "src.websockets.message_handlers", + logging.WARNING, + f"Could not update confirmation: '{exception_message}'", + ) in caplog.record_tuples diff --git a/backend/tests/websockets/user_confirmer_test.py b/backend/tests/websockets/user_confirmer_test.py new file mode 100644 index 000000000..896ee17af --- /dev/null +++ b/backend/tests/websockets/user_confirmer_test.py @@ -0,0 +1,73 @@ +import logging +from unittest.mock import Mock, patch + +import pytest + +from src.websockets.types import Message, MessageTypes +from src.websockets.user_confirmer import UserConfirmer +from src.websockets.confirmations_manager import ConfirmationsManager + + +class TestUserConfirmer: + @pytest.mark.asyncio + async def test_confirm_times_out(self, caplog): + # Arrange + confirmations_manager_mock = Mock(spec=ConfirmationsManager) + confirmations_manager_mock.get_confirmation_state.return_value = None + user_confirmer = UserConfirmer(confirmations_manager_mock) + user_confirmer._TIMEOUT_SECONDS = 0.05 + user_confirmer._POLL_RATE_SECONDS = 0.01 + + # Act + result = await user_confirmer.confirm("Test Message") + + # Assert + assert result is False + confirmations_manager_mock.add_confirmation.assert_called_once() + id = confirmations_manager_mock.add_confirmation.call_args.args[0] + confirmations_manager_mock.delete_confirmation.assert_called_once_with(id) + assert caplog.record_tuples == [ + ("src.websockets.user_confirmer", logging.WARNING, f"Confirmation with id {id} timed out.") + ] + + @pytest.mark.asyncio + @patch("src.websockets.connection_manager.connection_manager") + async def test_confirm_approved(self, connection_manager_mock): + # Arrange + confirmations_manager_mock = Mock(spec=ConfirmationsManager) + confirmations_manager_mock.get_confirmation_state.side_effect = [None, True] + user_confirmer = UserConfirmer(confirmations_manager_mock) + user_confirmer._POLL_RATE_SECONDS = 0.01 + + # Act + result = await user_confirmer.confirm("Test Message") + + # Assert + assert result is True + confirmations_manager_mock.add_confirmation.assert_called_once() + id = confirmations_manager_mock.add_confirmation.call_args.args[0] + connection_manager_mock.broadcast.awaited_once_with(Message(MessageTypes.CONFIRMATION, f"{id}:Test Message")) + confirmations_manager_mock.get_confirmation_state.assert_called_with(id) + assert confirmations_manager_mock.get_confirmation_state.call_count == 2 + confirmations_manager_mock.delete_confirmation.assert_called_once_with(id) + + @pytest.mark.asyncio + @patch("src.websockets.connection_manager.connection_manager") + async def test_confirm_denied(self, connection_manager_mock): + # Arrange + confirmations_manager_mock = Mock(spec=ConfirmationsManager) + confirmations_manager_mock.get_confirmation_state.side_effect = [None, False] + user_confirmer = UserConfirmer(confirmations_manager_mock) + user_confirmer._POLL_RATE_SECONDS = 0.01 + + # Act + result = await user_confirmer.confirm("Test Message") + + # Assert + assert result 
is False + confirmations_manager_mock.add_confirmation.assert_called_once() + id = confirmations_manager_mock.add_confirmation.call_args.args[0] + connection_manager_mock.broadcast.awaited_once_with(Message(MessageTypes.CONFIRMATION, f"{id}:Test Message")) + confirmations_manager_mock.get_confirmation_state.assert_called_with(id) + assert confirmations_manager_mock.get_confirmation_state.call_count == 2 + confirmations_manager_mock.delete_confirmation.assert_called_once_with(id) diff --git a/compose.yml b/compose.yml index b0b8fd3fb..4a979141f 100644 --- a/compose.yml +++ b/compose.yml @@ -40,10 +40,14 @@ services: start_period: 60s # InferGPT Backend backend: + env_file: + - .env image: infergpt/backend build: context: backend dockerfile: ./Dockerfile + volumes: + - ./${FILES_DIRECTORY}:/app/${FILES_DIRECTORY} environment: NEO4J_URI: bolt://neo4j-db:7687 NEO4J_USERNAME: ${NEO4J_USERNAME} @@ -51,6 +55,7 @@ services: MISTRAL_KEY: ${MISTRAL_KEY} OPENAI_KEY: ${OPENAI_KEY} FRONTEND_URL: ${FRONTEND_URL} + FILES_DIRECTORY: ${FILES_DIRECTORY} AZURE_STORAGE_CONNECTION_STRING: ${AZURE_STORAGE_CONNECTION_STRING} AZURE_STORAGE_CONTAINER_NAME: ${AZURE_STORAGE_CONTAINER_NAME} AZURE_INITIAL_DATA_FILENAME: ${AZURE_INITIAL_DATA_FILENAME} @@ -62,6 +67,7 @@ services: MATHS_AGENT_LLM: ${MATHS_AGENT_LLM} ROUTER_LLM: ${ROUTER_LLM} CHART_GENERATOR_LLM: ${CHART_GENERATOR_LLM} + FILE_AGENT_LLM: ${FILE_AGENT_LLM} ANSWER_AGENT_MODEL: ${ANSWER_AGENT_MODEL} INTENT_AGENT_MODEL: ${INTENT_AGENT_MODEL} VALIDATOR_AGENT_MODEL: ${VALIDATOR_AGENT_MODEL} @@ -71,6 +77,7 @@ services: MATHS_AGENT_MODEL: ${MATHS_AGENT_MODEL} AGENT_CLASS_MODEL: ${AGENT_CLASS_MODEL} CHART_GENERATOR_MODEL: ${CHART_GENERATOR_MODEL} + FILE_AGENT_MODEL: ${FILE_AGENT_MODEL} depends_on: neo4j-db: condition: service_healthy diff --git a/frontend/src/components/chat.tsx b/frontend/src/components/chat.tsx index 9c5fe9256..31bfec5bb 100644 --- a/frontend/src/components/chat.tsx +++ b/frontend/src/components/chat.tsx @@ -3,23 +3,37 @@ import { Message, MessageComponent } from './message'; import styles from './chat.module.css'; import { Waiting } from './waiting'; import { ConnectionStatus } from './connection-status'; -import { WebsocketContext, MessageType } from '../session/websocket-context'; - +import { WebsocketContext, MessageType, Message as wsMessage } from '../session/websocket-context'; +import { Confirmation, ConfirmModal } from './confirm-modal'; export interface ChatProps { messages: Message[]; waiting: boolean; } +const mapWsMessageToConfirmation = (message: wsMessage): Confirmation | undefined => { + if (!message.data) { + return; + } + const parts = message.data.split(':'); + return { id: parts[0], requestMessage: parts[1], result: null }; +}; + export const Chat = ({ messages, waiting }: ChatProps) => { const containerRef = React.useRef(null); - const { isConnected, lastMessage } = useContext(WebsocketContext); + const { isConnected, lastMessage, send } = useContext(WebsocketContext); const [chart, setChart] = useState(undefined); + const [confirmation, setConfirmation] = useState(null); useEffect(() => { if (lastMessage && lastMessage.type === MessageType.IMAGE) { const imageData = `data:image/png;base64,${lastMessage.data}`; setChart(imageData); } + if (lastMessage && lastMessage.type === MessageType.CONFIRMATION) { + const newConfirmation = mapWsMessageToConfirmation(lastMessage); + if (newConfirmation) + setConfirmation(newConfirmation); + } }, [lastMessage]); useEffect(() => { @@ -29,13 +43,16 @@ export const Chat = ({ messages, waiting 
}: ChatProps) => {
   }, [messages.length]);
 
   return (
-    <div ref={containerRef}>
-      <ConnectionStatus isConnected={isConnected} />
-      {messages.map((message, index) => (
-        <MessageComponent key={index} message={message} />
-      ))}
-      {chart && <img src={chart} alt="Generated chart" />}
-      {waiting && <Waiting />}
-    </div>
+    <>
+      <ConfirmModal confirmation={confirmation} setConfirmation={setConfirmation} send={send} />
+      <div ref={containerRef}>
+        <ConnectionStatus isConnected={isConnected} />
+        {messages.map((message, index) => (
+          <MessageComponent key={index} message={message} />
+        ))}
+        {chart && <img src={chart} alt="Generated chart" />}
+        {waiting && <Waiting />}
+      </div>
+    </>
   );
 };
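The confirmation exchange wired in above is a plain string protocol over the existing websocket: the backend broadcasts a `confirmation` message whose `data` is `"<uuid>:<question>"`, and the frontend replies with `"<uuid>:y"` or `"<uuid>:n"`. Below is a minimal TypeScript sketch of that round-trip; the helper names (`parseConfirmation`, `serialiseConfirmation`) are illustrative and not part of this change.

```typescript
// Shape used by the chat UI for a pending confirmation (mirrors the Confirmation
// interface introduced in confirm-modal.tsx).
interface Confirmation {
  id: string;
  requestMessage: string;
  result: boolean | null;
}

// Backend -> frontend: data is "<uuid>:<question to show the user>".
const parseConfirmation = (data: string): Confirmation | undefined => {
  const separatorIndex = data.indexOf(':');
  if (separatorIndex === -1) {
    return undefined;
  }
  return {
    id: data.slice(0, separatorIndex),
    requestMessage: data.slice(separatorIndex + 1),
    result: null,
  };
};

// Frontend -> backend: data is "<uuid>:y" or "<uuid>:n".
const serialiseConfirmation = (confirmation: Confirmation): string =>
  `${confirmation.id}:${confirmation.result ? 'y' : 'n'}`;

// Example round-trip.
const incoming = parseConfirmation('3fa85f64-5717-4562-b3fc-2c963f66afa6:Would you like to generate a graph?');
if (incoming) {
  const reply = serialiseConfirmation({ ...incoming, result: true });
  console.log(reply); // "3fa85f64-5717-4562-b3fc-2c963f66afa6:y"
}
```

Note that the in-tree `mapWsMessageToConfirmation` uses `data.split(':')` and keeps `parts[1]`, so a request message that itself contains a colon would be truncated; splitting only on the first separator, as sketched here, avoids that.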
diff --git a/frontend/src/components/confirm-modal.module.css b/frontend/src/components/confirm-modal.module.css
new file mode 100644
index 000000000..9004e1e33
--- /dev/null
+++ b/frontend/src/components/confirm-modal.module.css
@@ -0,0 +1,63 @@
+.modal{
+  width: 40%;
+  height: 40%;
+  background-color: #4c4c4c;
+  color: var(--text-color-primary);
+  border: 2px black;
+  border-radius: 10px;
+}
+
+.modalContent{
+  width: 100%;
+  height: 100%;
+  display: flex;
+  flex-direction: column;
+}
+
+.header{
+  text-align: center;
+}
+
+.modal::backdrop{
+  background: rgb(0,0,0,0.8);
+}
+
+.requestMessage{
+  flex-grow: 1;
+}
+
+.buttonsBar{
+  display: flex;
+  gap: 0.5rem;
+}
+
+.button{
+  color: var(--text-color-primary);
+  font-weight: bold;
+  border: none;
+  width: 100%;
+  padding: 1rem;
+  cursor: pointer;
+  border-radius: 3px;
+}
+
+
+.cancel{
+  composes: button;
+  background-color: var(--background-color-primary);
+}
+
+.cancel:hover{
+  background-color: #141414;
+  transition: all 0.5s;
+}
+
+.confirm{
+  composes: button;
+  background-color: var(--blue);
+}
+
+.confirm:hover{
+  background-color: #146AFF;
+  transition: all 0.5s;
+}
diff --git a/frontend/src/components/confirm-modal.tsx b/frontend/src/components/confirm-modal.tsx
new file mode 100644
index 000000000..9cadf0fc2
--- /dev/null
+++ b/frontend/src/components/confirm-modal.tsx
@@ -0,0 +1,57 @@
+import Styles from './confirm-modal.module.css';
+import { useEffect, useRef } from 'react';
+import { Message, MessageType } from '../session/websocket-context';
+import React from 'react';
+
+export interface Confirmation {
+  id: string,
+  requestMessage: string,
+  result: boolean | null
+}
+
+interface ConfirmModalProps {
+  confirmation: Confirmation | null,
+  setConfirmation: (confirmation: Confirmation | null) => void,
+  send: (message: Message) => void
+}
+
+export const ConfirmModal = ({ confirmation, setConfirmation, send }: ConfirmModalProps) => {
+  const mapConfirmationToMessage = (confirmation: Confirmation): Message => {
+    return { type: MessageType.CONFIRMATION, data: confirmation.id + ':' + (confirmation.result ? 'y' : 'n') };
+  };
+
+  const updateConfirmationResult = (newResult: boolean) => {
+    if (confirmation) {
+      setConfirmation({ ...confirmation, result: newResult });
+    }
+  };
+
+
+  const modalRef = useRef<HTMLDialogElement>(null);
+
+  useEffect(() => {
+    if (confirmation) {
+      if (confirmation.result !== null) {
+        send(mapConfirmationToMessage(confirmation));
+        setConfirmation(null);
+      } else {
+        modalRef.current?.showModal();
+      }
+    } else {
+      modalRef.current?.close();
+    }
+  }, [confirmation]);
+
+  return (
+    <dialog ref={modalRef} className={Styles.modal} onCancel={() => updateConfirmationResult(false)}>
+      <div className={Styles.modalContent}>
+        <h3 className={Styles.header}>Confirmation</h3>
+        <p className={Styles.requestMessage}>{confirmation && confirmation.requestMessage}</p>
+        <div className={Styles.buttonsBar}>
+          <button className={Styles.cancel} onClick={() => updateConfirmationResult(false)}>Cancel</button>
+          <button className={Styles.confirm} onClick={() => updateConfirmationResult(true)}>Confirm</button>
+        </div>
+      </div>
+    </dialog>
+  );
+};
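For reference, a minimal sketch of how a parent component can host the new `ConfirmModal`; `ConfirmationHost` and its single prop are illustrative, while the `ConfirmModal` props mirror `ConfirmModalProps` above.

```tsx
import React, { useState } from 'react';
import { ConfirmModal, Confirmation } from './confirm-modal';
import { Message } from '../session/websocket-context';

// Illustrative host: it only owns the confirmation state and a websocket `send`
// function. ConfirmModal opens its <dialog> while `result` is null and, once the
// user chooses, sends "<id>:y" or "<id>:n" and clears the state via setConfirmation.
export const ConfirmationHost = ({ send }: { send: (message: Message) => void }) => {
  const [confirmation, setConfirmation] = useState<Confirmation | null>(null);

  // A websocket handler elsewhere would call setConfirmation(...) with the parsed
  // request when a 'confirmation' message arrives.
  return <ConfirmModal confirmation={confirmation} setConfirmation={setConfirmation} send={send} />;
};
```

Keeping the modal driven entirely by the `confirmation` prop means the parent never touches the `<dialog>` element directly; it only supplies state and a transport.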
diff --git a/frontend/src/session/websocket-context.ts b/frontend/src/session/websocket-context.ts
index 0c56803e6..a684bc0d5 100644
--- a/frontend/src/session/websocket-context.ts
+++ b/frontend/src/session/websocket-context.ts
@@ -3,7 +3,8 @@ import { createContext } from 'react';
 export enum MessageType {
   PING = 'ping',
   CHAT = 'chat',
-  IMAGE = 'image'
+  IMAGE = 'image',
+  CONFIRMATION = 'confirmation',
 }
 
 export interface Message {
diff --git a/frontend/webpack.config.js b/frontend/webpack.config.js
index 3fcd4a0a4..ecfc2c7f8 100644
--- a/frontend/webpack.config.js
+++ b/frontend/webpack.config.js
@@ -7,14 +7,17 @@ import { fileURLToPath } from 'url';
 const __dirname = path.dirname(fileURLToPath(import.meta.url));
 
 const localEnv = dotenv.config({ path: path.resolve(__dirname, '../.env') }).parsed;
-const env = { ...process.env, ...localEnv };
+const env = { ...process.env, ...localEnv };
 
 const config = {
   mode: 'development',
   entry: './src/index.tsx',
   output: {
-    path: __dirname + '/dist/',
+    path: path.resolve(__dirname, 'dist'),
+    publicPath: '/',
+    filename: '[name].bundle.js'
   },
+
   module: {
     rules: [
       {