
Commit

Merge pull request WaitThatShouldntWork#98 from WaitThatShouldntWork/release/improve-reliability

Release/improve reliability
hsauve-scottlogic authored Oct 18, 2024
2 parents 6f236b2 + 70524a8 commit 72e117e
Showing 56 changed files with 1,922 additions and 557 deletions.
5 changes: 5 additions & 0 deletions .env.example
@@ -13,6 +13,9 @@ NEO4J_URI=bolt://localhost:7687
NEO4J_HTTP_PORT=7474
NEO4J_BOLT_PORT=7687

# files location
FILES_DIRECTORY=files

# backend LLM properties
MISTRAL_KEY=my-api-key

@@ -42,6 +45,7 @@ MATHS_AGENT_LLM="openai"
WEB_AGENT_LLM="openai"
CHART_GENERATOR_LLM="openai"
ROUTER_LLM="openai"
FILE_AGENT_LLM="openai"

# llm model
ANSWER_AGENT_MODEL="gpt-4o mini"
@@ -52,3 +56,4 @@ MATHS_AGENT_MODEL="gpt-4o mini"
WEB_AGENT_MODEL="gpt-4o mini"
CHART_GENERATOR_MODEL="gpt-4o mini"
ROUTER_MODEL="gpt-4o mini"
FILE_AGENT_MODEL="gpt-4o mini"
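
The two new FILE_AGENT_* keys follow the existing one-LLM-and-one-model-per-agent pattern. As a rough sketch only (the real `src/utils/config.py` is not part of this diff), a config class could surface them like this, matching the `config.file_agent_llm`, `config.file_agent_model` and `config.files_directory` attributes used further down:

```python
# Hypothetical sketch of how the new keys could be exposed; the actual Config class is not shown in this diff.
import os
from dotenv import load_dotenv  # python-dotenv, already in requirements.txt


class Config:
    def __init__(self):
        load_dotenv()
        self.files_directory = os.getenv("FILES_DIRECTORY", "files")
        self.file_agent_llm = os.getenv("FILE_AGENT_LLM", "openai")
        self.file_agent_model = os.getenv("FILE_AGENT_MODEL", "gpt-4o mini")
```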
1 change: 1 addition & 0 deletions .gitignore
@@ -127,6 +127,7 @@ celerybeat.pid
# Environments
.env
.venv
files
env/
venv/
ENV/
2 changes: 1 addition & 1 deletion backend/README.md
@@ -37,7 +37,7 @@ Follow the instructions below to run the backend locally. Change directory to `/
```bash
pip install -r requirements.txt
```

> (VsCode) You may run into some issues with compiling python packages from requirements.txt. To resolve this ensure you have downloaded and installed the "Desktop development with C++" workload from your Visual Studio installer.
3. Run the app

```bash
2 changes: 1 addition & 1 deletion backend/requirements.txt
@@ -1,6 +1,6 @@
fastapi==0.110.0
uvicorn==0.29.0
mistralai==0.1.8
mistralai==1.1.0
pycodestyle==2.11.1
python-dotenv==1.0.1
neo4j==5.18.0
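The jump from mistralai 0.1.8 to 1.1.0 crosses the SDK's 1.0 rewrite, in which the client class and chat call changed. Purely as an illustration of the v1-style API (the repository's own LLM wrapper is not shown in this hunk):

```python
# Illustrative mistralai >= 1.0 usage only; not code from this repository.
import os
from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_KEY"])
response = client.chat.complete(
    model="mistral-large-latest",
    messages=[{"role": "user", "content": "Say hello"}],
)
print(response.choices[0].message.content)
```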
5 changes: 5 additions & 0 deletions backend/src/agents/__init__.py
@@ -8,6 +8,9 @@
from .validator_agent import ValidatorAgent
from .answer_agent import AnswerAgent
from .chart_generator_agent import ChartGeneratorAgent
from .file_agent import FileAgent
from .maths_agent import MathsAgent


config = Config()

@@ -32,6 +35,8 @@ def get_available_agents() -> List[Agent]:
return [DatastoreAgent(config.datastore_agent_llm, config.datastore_agent_model),
WebAgent(config.web_agent_llm, config.web_agent_model),
ChartGeneratorAgent(config.chart_generator_llm, config.chart_generator_model),
FileAgent(config.file_agent_llm, config.file_agent_model),
MathsAgent(config.maths_agent_llm, config.maths_agent_model),
]


1 change: 0 additions & 1 deletion backend/src/agents/agent.py
@@ -56,7 +56,6 @@ async def __get_action(self, utterance: str) -> Action_and_args:

async def invoke(self, utterance: str) -> str:
(action, args) = await self.__get_action(utterance)
logger.info(f"USER - Action: {action} and args: {args} for utterance: {utterance}")
result_of_action = await action(**args, llm=self.llm, model=self.model)
await publish_log_info(LogPrefix.USER, f"Action gave result: {result_of_action}", __name__)
return result_of_action
25 changes: 17 additions & 8 deletions backend/src/agents/chart_generator_agent.py
@@ -9,11 +9,14 @@
from src.utils import scratchpad
from PIL import Image
import json
# from src.websockets.user_confirmer import UserConfirmer
# from src.websockets.confirmations_manager import confirmations_manager

logger = logging.getLogger(__name__)

engine = PromptEngine()


async def generate_chart(question_intent, data_provided, question_params, llm: LLM, model) -> str:
details_to_generate_chart_code = engine.load_prompt(
"details-to-generate-chart-code",
@@ -28,13 +31,18 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
sanitised_script = sanitise_script(generated_code)

try:
# confirmer = UserConfirmer(confirmations_manager)
is_confirmed = True
# await confirmer.confirm("Would you like to generate a graph?")
if not is_confirmed:
raise Exception("The user did not confirm to creating a graph.")
local_vars = {}
exec(sanitised_script, {}, local_vars)
fig = local_vars.get('fig')
fig = local_vars.get("fig")
buf = BytesIO()
if fig is None:
raise ValueError("The generated code did not produce a figure named 'fig'.")
fig.savefig(buf, format='png')
fig.savefig(buf, format="png")
buf.seek(0)
with Image.open(buf):
image_data = base64.b64encode(buf.getvalue()).decode("utf-8")
@@ -44,7 +52,7 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
raise
response = {
"content": image_data,
"ignore_validation": "false"
"ignore_validation": "true",
}
return json.dumps(response, indent=4)

@@ -57,6 +65,7 @@ def sanitise_script(script: str) -> str:
script = script[:-3]
return script.strip()


@tool(
name="generate_code_chart",
description="Generate Matplotlib bar chart code if the user's query involves creating a chart",
@@ -74,18 +83,18 @@ def sanitise_script(script: str) -> str:
description="""
The specific parameters required for the question to be answered with the question_intent,
extracted from data_provided
"""),
}
""",
),
},
)

async def generate_code_chart(question_intent, data_provided, question_params, llm: LLM, model) -> str:
return await generate_chart(question_intent, data_provided, question_params, llm, model)


@agent(
name="ChartGeneratorAgent",
description="This agent is responsible for creating charts",
tools=[generate_code_chart]
tools=[generate_code_chart],
)

class ChartGeneratorAgent(Agent):
pass
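
The `exec` block in `generate_chart` only succeeds if the sanitised script leaves a Matplotlib figure in a local variable named `fig`; anything else raises the "did not produce a figure named 'fig'" error. A hypothetical example of the kind of script the LLM is expected to return (the actual output format lives in the `details-to-generate-chart-code` prompt, which is not shown here):

```python
# Hypothetical generated script: it must define `fig`, and it must carry its own imports,
# because generate_chart() executes it with empty globals and reads `fig` out of local_vars.
import matplotlib
matplotlib.use("Agg")  # headless backend; the figure is only saved to an in-memory buffer
import matplotlib.pyplot as plt

categories = ["Q1", "Q2", "Q3"]
values = [120, 150, 90]

fig, ax = plt.subplots()
ax.bar(categories, values)
ax.set_title("Quarterly spend")
```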
86 changes: 55 additions & 31 deletions backend/src/agents/datastore_agent.py
@@ -1,3 +1,4 @@
import json
import logging
from src.llm.llm import LLM
from src.utils.graph_db_utils import execute_query
@@ -8,15 +9,51 @@
from src.utils.log_publisher import LogPrefix, publish_log_info
from .agent import Agent, agent
from .tool import tool
import json

from src.utils.semantic_layer_builder import get_semantic_layer

logger = logging.getLogger(__name__)

engine = PromptEngine()

graph_schema = engine.load_prompt("graph-schema")
cache = {}

async def generate_cypher_query_core(
question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
) -> str:

details_to_create_cypher_query = engine.load_prompt(
"details-to-create-cypher-query",
question_intent=question_intent,
operation=operation,
question_params=question_params,
aggregation=aggregation,
sort_order=sort_order,
timeframe=timeframe,
)
try:
graph_schema = await get_semantic_layer_cache(llm, model, cache)
graph_schema = json.dumps(graph_schema, separators=(",", ":"))

generate_cypher_query_prompt = engine.load_prompt(
"generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
)

llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query,
return_json=True)
json_query = to_json(llm_query)
await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
if json_query["query"] == "None":
return "No database query"
db_response = execute_query(json_query["query"])
await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__)
except Exception as e:
logger.error(f"Error during data retrieval: {e}")
raise
response = {
"content": db_response,
"ignore_validation": "false"
}
return json.dumps(response, indent=4)

@tool(
name="generate cypher query",
@@ -51,39 +88,26 @@
),
},
)
async def generate_query(
question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
) -> str:
details_to_create_cypher_query = engine.load_prompt(
"details-to-create-cypher-query",
question_intent=question_intent,
operation=operation,
question_params=question_params,
aggregation=aggregation,
sort_order=sort_order,
timeframe=timeframe,
)
generate_cypher_query_prompt = engine.load_prompt(
"generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
)
llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True)
json_query = to_json(llm_query)
await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
if json_query["query"] == "None":
return "No database query"
db_response = execute_query(json_query["query"])
await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__)
response = {
"content": db_response,
"ignore_validation": "false"
}
return json.dumps(response, indent=4)

async def generate_cypher(question_intent, operation, question_params, aggregation, sort_order,
timeframe, llm: LLM, model) -> str:
return await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order,
timeframe, llm, model)


async def get_semantic_layer_cache(llm, model, graph_schema):
global cache
if not cache:
graph_schema = await get_semantic_layer(llm, model)
cache = graph_schema
return cache
else:
return cache

@agent(
name="DatastoreAgent",
description="This agent is responsible for handling database queries relating to the user's personal data.",
tools=[generate_query],
tools=[generate_cypher],
)
class DatastoreAgent(Agent):
pass
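
Two details of the rewritten flow are easy to miss: the LLM is expected to reply with JSON whose `query` key holds either a Cypher string or the literal string "None" as a no-query sentinel, and the semantic layer is fetched once and then reused from the module-level `cache` dict. A small sketch of the assumed response shape, treating `to_json` as a plain JSON parse (an assumption, since that helper is not shown in this diff):

```python
# Assumed shape of the LLM reply; the prompt templates and to_json() are not shown in this diff.
import json

llm_query = '{"query": "MATCH (t:Transaction) RETURN sum(t.amount) AS total"}'
json_query = json.loads(llm_query)   # stand-in for to_json()

if json_query["query"] == "None":    # sentinel meaning "no database query is appropriate"
    print("No database query")
else:
    print(json_query["query"])       # this string would be handed to execute_query()
```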
102 changes: 102 additions & 0 deletions backend/src/agents/file_agent.py
@@ -0,0 +1,102 @@
import logging
from .agent_types import Parameter
from .agent import Agent, agent
from .tool import tool
import json
import os
from src.utils.config import Config

logger = logging.getLogger(__name__)
config = Config()

FILES_DIRECTORY = f"/app/{config.files_directory}"

# Constants for response status
IGNORE_VALIDATION = "true"
STATUS_SUCCESS = "success"
STATUS_ERROR = "error"

# Utility function for error responses
def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
return json.dumps({
"content": content,
"ignore_validation": IGNORE_VALIDATION,
"status": status
}, indent=4)

async def read_file_core(file_path: str) -> str:
full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
try:
with open(full_path, 'r') as file:
content = file.read()
return create_response(content)
except FileNotFoundError:
error_message = f"File {file_path} not found."
logger.error(error_message)
return create_response(error_message, STATUS_ERROR)
except Exception as e:
logger.error(f"Error reading file {full_path}: {e}")
return create_response(f"Error reading file: {file_path}", STATUS_ERROR)


async def write_or_update_file_core(file_path: str, content: str, update) -> str:
full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
try:
if update == "yes":
with open(full_path, 'a') as file:
file.write('\n' +content)
logger.info(f"Content appended to file {full_path} successfully.")
return create_response(f"Content appended to file {file_path}.")
else:
with open(full_path, 'w') as file:
file.write(content)
logger.info(f"Content written to file {full_path} successfully.")
return create_response(f"Content written to file {file_path}.")
except Exception as e:
logger.error(f"Error writing to file {full_path}: {e}")
return create_response(f"Error writing to file: {file_path}", STATUS_ERROR)


@tool(
name="read_file",
description="Read the content of a text file.",
parameters={
"file_path": Parameter(
type="string",
description="The path to the file to be read."
),
},
)
async def read_file(file_path: str, llm, model) -> str:
return await read_file_core(file_path)


@tool(
name="write_file",
description="Write or update content to a text file.",
parameters={
"file_path": Parameter(
type="string",
description="The path to the file where the content will be written."
),
"content": Parameter(
type="string",
description="The content to write to the file."
),
"update": Parameter(
type="string",
description="if yes then just append the file"
),
},
)
async def write_or_update_file(file_path: str, content: str, update, llm, model) -> str:
return await write_or_update_file_core(file_path, content, update)


@agent(
name="FileAgent",
description="This agent is responsible for reading from and writing to files.",
tools=[read_file, write_or_update_file],
)
class FileAgent(Agent):
pass
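
A quick way to exercise the two core helpers outside the agent/tool plumbing, assuming the `/app/files` directory (or whatever `FILES_DIRECTORY` resolves to) exists and is writable:

```python
# Hypothetical local usage of the core helpers; in normal operation the FileAgent calls them via its tools.
import asyncio
import json


async def demo():
    await write_or_update_file_core("notes.txt", "first line", update="no")
    await write_or_update_file_core("notes.txt", "second line", update="yes")  # append
    read_result = await read_file_core("notes.txt")
    print(json.loads(read_result)["content"])


asyncio.run(demo())
```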