
Commit
Address Ivan's comments and fix tests
evpearce committed Dec 18, 2024
1 parent e465d44 commit 9d4803e
Showing 12 changed files with 164 additions and 187 deletions.
17 changes: 5 additions & 12 deletions backend/src/agents/report_agent.py
@@ -11,25 +11,18 @@
 
 class ReportAgent(Agent):
     async def create_report(self, file: LLMFile, materiality_topics: dict[str, str]) -> str:
-        user_prompt = engine.load_prompt(
-            "create-report-user-prompt",
-            materiality_topics=materiality_topics
-        )
-
-        system_prompt = engine.load_prompt("create-report-system-prompt")
-
         return await self.llm.chat_with_file(
             self.model,
-            system_prompt=system_prompt,
-            user_prompt=user_prompt,
-            files=[file]
-        )
+            system_prompt=engine.load_prompt("create-report-system-prompt"),
+            user_prompt=engine.load_prompt("create-report-user-prompt", materiality_topics=materiality_topics),
+            files=[file],
+        )
 
     async def get_company_name(self, file: LLMFile) -> str:
         response = await self.llm.chat_with_file(
             self.model,
             system_prompt=engine.load_prompt("find-company-name-from-file-system-prompt"),
             user_prompt=engine.load_prompt("find-company-name-from-file-user-prompt"),
-            files=[file]
+            files=[file],
         )
         return json.loads(response)["company_name"]
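
For orientation, a minimal usage sketch of the refactored agent (not part of this commit; it assumes get_report_agent() returns a wired-up ReportAgent as in report_director.py below, and the file bytes and topics are placeholders):

import asyncio

from src.agents import get_report_agent
from src.llm.llm import LLMFile


async def main():
    file = LLMFile(file_name="annual_report.pdf", file=b"%PDF-1.4 ...")  # placeholder bytes
    agent = get_report_agent()
    company = await agent.get_company_name(file)  # LLM replies with {"company_name": ...} JSON
    topics = {"Climate change": "Scope 1-3 emissions"}  # placeholder materiality topics
    report_md = await agent.create_report(file, topics)
    print(company, report_md[:100])


asyncio.run(main())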
12 changes: 7 additions & 5 deletions backend/src/api/app.py
@@ -7,8 +7,8 @@
 from fastapi.middleware.cors import CORSMiddleware
 from src.utils.scratchpad import ScratchPadMiddleware
 from src.session.chat_response import get_session_chat_response_ids
-from src.chat_storage_service import clear_chat_messages, get_chat_message
-from src.directors.report_director import report_on_file_upload
+from src.chat_storage_service import clear_chat_messages, get_chat_message, get_chat_message
+from src.directors.report_director import create_report_from_file
 from src.session.file_uploads import clear_session_file_uploads, get_report
 from src.session.redis_session_middleware import reset_session
 from src.utils import Config, test_connection

Check failure on line 10 in backend/src/api/app.py (GitHub Actions / Linting Backend):
Ruff (F811) backend/src/api/app.py:10:77: F811 Redefinition of unused `get_chat_message` from line 10
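
The F811 failure flags the duplicated get_chat_message in the combined import; presumably the intended line is simply the deduplicated form:

from src.chat_storage_service import clear_chat_messages, get_chat_message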
@@ -129,27 +129,29 @@ async def suggestions():
 async def report(file: UploadFile):
     logger.info(f"upload file type={file.content_type} name={file.filename} size={file.size}")
     try:
-        processed_upload = await report_on_file_upload(file)
+        processed_upload = await create_report_from_file(file)
         return JSONResponse(status_code=200, content=processed_upload)
     except HTTPException as he:
         raise he
     except Exception as e:
         logger.exception(e)
         return JSONResponse(status_code=500, content=file_upload_failed_response)
 
+
 @app.get("/report/{id}")
 def download_report(id: str):
     logger.info(f"Get report download called for id: {id}")
     try:
         final_result = get_report(id)
         if final_result is None:
             return JSONResponse(status_code=404, content=f"Message with id {id} not found")
-        headers = {'Content-Disposition': 'attachment; filename="report.md"'}
-        return Response(final_result.get("report"), headers=headers, media_type='text/markdown')
+        headers = {"Content-Disposition": 'attachment; filename="report.md"'}
+        return Response(final_result.get("report"), headers=headers, media_type="text/markdown")
     except Exception as e:
         logger.exception(e)
         return JSONResponse(status_code=500, content=report_get_upload_failed_response)
 
+
 @app.get("/uploadfile")
 async def fetch_file(id: str):
     logger.info(f"fetch uploaded file id={id} ")
35 changes: 18 additions & 17 deletions backend/src/directors/report_director.py
@@ -1,48 +1,49 @@
 import logging
+import sys
 import uuid
-from fastapi import UploadFile
+from fastapi import UploadFile, HTTPException
 
 from src.llm.llm import LLMFile
-from src.session.file_uploads import FileUploadReport, store_report
+from src.session.file_uploads import ReportResponse, store_report
 from src.agents import get_report_agent, get_materiality_agent
 
 logger = logging.getLogger(__name__)
+MAX_FILE_SIZE = 10 * 1024 * 1024
 
-async def report_on_file_upload(upload: UploadFile) -> FileUploadReport:
+
+async def create_report_from_file(upload: UploadFile) -> ReportResponse:
     file_stream = await upload.read()
-    if upload.filename is None:
-        raise ValueError("Filename cannot be None")
+    if upload.filename is None or upload.filename == "":
+        raise HTTPException(status_code=400, detail="Filename missing from file upload")
 
+    file_size = sys.getsizeof(file_stream)
+
+    if file_size > MAX_FILE_SIZE:
+        raise HTTPException(status_code=413, detail=f"File upload must be less than {MAX_FILE_SIZE} bytes")
+
     file = LLMFile(file_name=upload.filename, file=file_stream)
     file_id = str(uuid.uuid4())
 
     report_agent = get_report_agent()
 
     company_name = await report_agent.get_company_name(file)
     logger.info(f"Company name: {company_name}")
 
     topics = await get_materiality_agent().list_material_topics(company_name)
 
     logger.info(f"Topics are: {topics}")
     report = await report_agent.create_report(file, topics)
     logger.info(f"Report: {report}")
 
-    report_upload = FileUploadReport(
+    report_response = ReportResponse(
         filename=file.file_name,
         id=file_id,
         report=report,
-        answer=create_report_chat_message(file.file_name, company_name, topics)
+        answer=create_report_chat_message(file.file_name, company_name, topics),
     )
 
-    store_report(report_upload)
+    store_report(report_response)
 
-    return report_upload
+    return report_response
 
 
 def create_report_chat_message(file_name: str, company_name: str, topics: dict[str, str]) -> str:
-    topics_with_markdown = [
-        f"{key}\n{value}" for key, value in topics.items()
-    ]
+    topics_with_markdown = [f"{key}\n{value}" for key, value in topics.items()]
     return f"""Your report for {file_name} is ready to view.
 The following materiality topics were identified for {company_name} which the report focuses on:
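
A sketch (not in this commit) of exercising the new size guard end-to-end through the /report endpoint, assuming the TestClient wiring from app_test.py below; the 413 is raised before any agent is called, so no mocking is needed:

from fastapi.testclient import TestClient
from src.api import app

client = TestClient(app)


def test_report_rejects_oversized_upload():
    # One byte over MAX_FILE_SIZE; sys.getsizeof adds object overhead on top.
    oversized = b"x" * (10 * 1024 * 1024 + 1)
    response = client.post("/report", files={"file": ("big.txt", oversized, "text/plain")})
    assert response.status_code == 413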
31 changes: 5 additions & 26 deletions backend/src/llm/openai.py
@@ -19,7 +19,6 @@ def remove_citations(message: Text):
 
 
 class OpenAI(LLM):
-
     async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
         logger.debug(
             "##### Called open ai chat ... llm. Waiting on response model with prompt {0}.".format(
@@ -35,7 +34,7 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
                     {"role": "user", "content": user_prompt},
                 ],
                 temperature=0,
-                response_format={"type": "json_object"} if return_json else NOT_GIVEN
+                response_format={"type": "json_object"} if return_json else NOT_GIVEN,
             )
             content = response.choices[0].message.content
             logger.info(f"OpenAI response: Finish reason: {response.choices[0].finish_reason}, Content: {content}")
@@ -50,13 +49,7 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
             logger.error(f"Error calling OpenAI model: {e}")
             return "An error occurred while processing the request."
 
-    async def chat_with_file(
-        self,
-        model: str,
-        system_prompt: str,
-        user_prompt: str,
-        files: list[LLMFile]
-    ) -> str:
+    async def chat_with_file(self, model: str, system_prompt: str, user_prompt: str, files: list[LLMFile]) -> str:
         client = AsyncOpenAI(api_key=config.openai_key)
         file_ids = await self.__upload_files(files)
 
@@ -72,17 +65,12 @@
                 {
                     "role": "user",
                     "content": user_prompt,
-                    "attachments": [
-                        {"file_id": file_id, "tools": [{"type": "file_search"}]}
-                        for file_id in file_ids
-                    ],
+                    "attachments": [{"file_id": file_id, "tools": [{"type": "file_search"}]} for file_id in file_ids],
                 }
             ]
         )
 
-        run = await client.beta.threads.runs.create_and_poll(
-            thread_id=thread.id, assistant_id=file_assistant.id
-        )
+        run = await client.beta.threads.runs.create_and_poll(thread_id=thread.id, assistant_id=file_assistant.id)
 
         messages = await client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
 
@@ -100,16 +88,7 @@ async def __upload_files(self, files: list[LLMFile]) -> list[str]:
         file_ids = []
         for file in files:
             logger.info(f"Uploading file '{file.file_name}' to OpenAI")
-            if isinstance(file.file, (PathLike, str)):
-                file_path = Path(file.file)
-                with file_path.open("rb") as f:
-                    file_bytes = f.read()
-            elif isinstance(file.file, bytes):
-                file_bytes = file.file
-            else:
-                logger.error(f"Unsupported file type for '{file.file_name}'")
-                continue
-            file = await client.files.create(file=(file.file_name, file_bytes), purpose="assistants")
+            file = await client.files.create(file=(file.file_name, file.file), purpose="assistants")
             file_ids.append(file.id)
 
         return file_ids
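
The simplified __upload_files leans on the OpenAI SDK accepting a (filename, content) tuple, so callers are now expected to supply bytes directly (as report_director.py does). A hedged usage sketch, assuming OpenAI() can be constructed with no arguments here and "gpt-4o" standing in for whatever model the config selects:

import asyncio

from src.llm.llm import LLMFile
from src.llm.openai import OpenAI


async def main():
    llm = OpenAI()
    answer = await llm.chat_with_file(
        "gpt-4o",  # placeholder model name
        system_prompt="You summarise documents.",
        user_prompt="Summarise the attached file in one paragraph.",
        files=[LLMFile(file_name="notes.txt", file=b"Quarterly emissions fell 12%.")],
    )
    print(answer)


asyncio.run(main())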
6 changes: 3 additions & 3 deletions backend/src/session/file_uploads.py
@@ -32,7 +32,7 @@ class FileUpload(TypedDict):
     size: Optional[int]
 
 
-class FileUploadReport(TypedDict):
+class ReportResponse(TypedDict):
     id: str
     answer: str
     filename: Optional[str]
@@ -83,9 +83,9 @@ def clear_session_file_uploads():
     set_session(UPLOADS_META_SESSION_KEY, [])
 
 
-def store_report(report: FileUploadReport):
+def store_report(report: ReportResponse):
     redis_client.set(REPORT_KEY_PREFIX + report["id"], json.dumps(report))
 
 
-def get_report(id: str) -> FileUploadReport | None:
+def get_report(id: str) -> ReportResponse | None:
     return _get_key(REPORT_KEY_PREFIX + id)
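
A round-trip sketch for the renamed type (assumes a reachable Redis behind redis_client, and that _get_key returns the JSON-decoded dict):

from src.session.file_uploads import ReportResponse, get_report, store_report

report = ReportResponse(id="abc-123", answer="Your report is ready.", filename="doc.pdf", report="# Report\n...")
store_report(report)  # stored under REPORT_KEY_PREFIX + "abc-123"
assert get_report("abc-123") == report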
22 changes: 4 additions & 18 deletions backend/src/utils/file_utils.py
@@ -1,5 +1,6 @@
 from io import BytesIO, TextIOWrapper
 from pathlib import Path
+import sys
 import time
 from fastapi import HTTPException
 import logging
@@ -11,8 +12,6 @@
 
 logger = logging.getLogger(__name__)
 
-MAX_FILE_SIZE = 10*1024*1024
-
 
 def handle_file_upload(file: LLMFile) -> FileUpload:
     if isinstance(file.file, (PathLike, str)):
@@ -22,19 +21,10 @@ def handle_file_upload(file: LLMFile) -> FileUpload:
     elif isinstance(file.file, bytes):
         file_bytes = file.file
     else:
-        raise HTTPException(
-            status_code=400,
-            detail="File must be provided as bytes or a valid file path."
-        )
+        raise HTTPException(status_code=400, detail="File must be provided as bytes or a valid file path.")
 
     file_stream = BytesIO(file_bytes)
-    file_size = len(file_bytes)
-
-    if file_size > MAX_FILE_SIZE:
-        raise HTTPException(
-            status_code=413,
-            detail=f"File upload must be less than {MAX_FILE_SIZE} bytes"
-        )
+    file_size = sys.getsizeof(file_bytes)
 
     all_content = ""
     content_type = "unknown"
@@ -63,8 +53,7 @@ def handle_file_upload(file: LLMFile) -> FileUpload:
 
         except Exception as text_error:
             raise HTTPException(
-                status_code=400,
-                detail="File upload must be a supported type text or pdf"
+                status_code=400, detail="File upload must be a supported type text or pdf"
             ) from text_error
 
     session_file = FileUpload(
@@ -82,6 +71,3 @@ def handle_file_upload(file: LLMFile) -> FileUpload:
 
 def get_file_upload(upload_id) -> FileUpload | None:
     return get_session_file_upload(upload_id)
-
-
-
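One nuance in the switch from len(file_bytes) to sys.getsizeof(file_bytes) (an editorial note, not part of the diff): getsizeof reports the object's in-memory footprint, which on a 64-bit CPython includes roughly 33 bytes of fixed overhead for a bytes object, so the effective limit is marginally stricter than a pure length check:

import sys

payload = b"x" * 1024
print(len(payload))            # 1024, the payload length the old check measured
print(sys.getsizeof(payload))  # 1024 plus ~33 bytes of CPython object overhead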
23 changes: 0 additions & 23 deletions backend/tests/agents/report_agent_test.py

This file was deleted.
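
The deleted tests are not shown; a hypothetical sketch of replacement coverage for get_company_name, bypassing Agent.__init__ (whose wiring is not visible in this diff) and stubbing the LLM — it also assumes the prompt templates load in the test environment:

from unittest.mock import AsyncMock

import pytest

from src.agents.report_agent import ReportAgent
from src.llm.llm import LLMFile


@pytest.mark.asyncio
async def test_get_company_name_parses_llm_json():
    agent = ReportAgent.__new__(ReportAgent)  # skip __init__; constructor args unknown here
    agent.llm = AsyncMock()
    agent.model = "test-model"
    agent.llm.chat_with_file.return_value = '{"company_name": "Acme Corp"}'

    name = await agent.get_company_name(LLMFile(file_name="doc.pdf", file=b"data"))

    assert name == "Acme Corp"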

14 changes: 7 additions & 7 deletions backend/tests/api/app_test.py
@@ -1,7 +1,7 @@
 from fastapi.testclient import TestClient
 import pytest
 from src.chat_storage_service import ChatResponse
-from src.directors.report_director import FileUploadReport
+from src.directors.report_director import ReportResponse
 from src.api import app, healthy_response, unhealthy_neo4j_response, chat_fail_response
 
 client = TestClient(app)
@@ -87,14 +87,14 @@ def test_chat_message_not_found(mocker):
 
 
 def test_report_response_success(mocker):
-    mock_response = FileUploadReport(filename="filename", id="1", report="some report md", answer="chat message")
-    mock_report = mocker.patch("src.api.app.report_on_file_upload", return_value=mock_response)
+    mock_response = ReportResponse(filename="filename", id="1", report="some report md", answer="chat message")
+    mock_report = mocker.patch("src.api.app.create_report_from_file", return_value=mock_response)
 
     response = client.post("/report", files={"file": ("filename", "test data".encode("utf-8"), "text/plain")})
 
     mock_report.assert_called_once()
     assert response.status_code == 200
-    assert response.json() == {'filename': 'filename', 'id': '1', 'report': 'some report md', 'answer': 'chat message'}
+    assert response.json() == {"filename": "filename", "id": "1", "report": "some report md", "answer": "chat message"}
 
 
 @pytest.mark.asyncio
@@ -106,15 +106,15 @@ async def test_lifespan_populates_db(mocker) -> None:
 
 
 def test_get_report_success(mocker):
-    report = FileUploadReport(id="12", filename="test.pdf", report="test report", answer='chat message')
+    report = ReportResponse(id="12", filename="test.pdf", report="test report", answer="chat message")
     mock_get_report = mocker.patch("src.api.app.get_report", return_value=report)
 
     response = client.get("/report/12")
 
     mock_get_report.assert_called_with("12")
     assert response.status_code == 200
-    assert response.headers.get('Content-Disposition') == 'attachment; filename="report.md"'
-    assert response.headers.get('Content-Type') == 'text/markdown; charset=utf-8'
+    assert response.headers.get("Content-Disposition") == 'attachment; filename="report.md"'
+    assert response.headers.get("Content-Type") == "text/markdown; charset=utf-8"
 
 
 def test_get_report_not_found(mocker):
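
The truncated test_get_report_not_found presumably asserts the 404 branch of download_report; a sketch consistent with the endpoint shown above:

def test_get_report_not_found(mocker):
    mocker.patch("src.api.app.get_report", return_value=None)

    response = client.get("/report/99")

    assert response.status_code == 404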
