From c55973b37bb0112db40a319800c82a31202a55a0 Mon Sep 17 00:00:00 2001 From: "Ivan Mladjenovic (He/Him)" Date: Fri, 13 Dec 2024 11:08:12 +0000 Subject: [PATCH] further type check fixes --- backend/src/agents/agent.py | 4 ++-- backend/src/directors/report_director.py | 10 ---------- backend/src/llm/mistral.py | 13 ++++++++++--- backend/src/llm/openai.py | 10 +++++++--- backend/src/session/file_uploads.py | 23 ++++++++++++++--------- backend/tests/api/app_test.py | 15 +++++++++++---- 6 files changed, 44 insertions(+), 31 deletions(-) diff --git a/backend/src/agents/agent.py b/backend/src/agents/agent.py index 052ad3b2..7389a0c0 100644 --- a/backend/src/agents/agent.py +++ b/backend/src/agents/agent.py @@ -1,7 +1,7 @@ from abc import ABC import json import logging -from typing import List, Type, Union, TypeVar +from typing import List, Type, Union, TypeVar, Optional from src.llm import LLM, get_llm from src.utils.log_publisher import LogPrefix, publish_log_info @@ -67,7 +67,7 @@ async def invoke(self, utterance: str) -> str: T = TypeVar('T', bound=Agent) -def agent(name: str, description: str, tools: List[Tool] = None): +def agent(name: str, description: str, tools: Optional[List[Tool]] = None): def decorator(agent: Type[T]) -> Type[T]: agent.name = name agent.description = description diff --git a/backend/src/directors/report_director.py b/backend/src/directors/report_director.py index 431db152..f3f4a18c 100644 --- a/backend/src/directors/report_director.py +++ b/backend/src/directors/report_director.py @@ -1,20 +1,10 @@ -from typing import TypedDict, Optional from fastapi import UploadFile -from dataclasses import dataclass from src.session.file_uploads import FileUploadReport, store_report from src.utils.file_utils import handle_file_upload from src.agents import get_report_agent, get_materiality_agent -@dataclass -class FileUploadReport(TypedDict): - id: str - answer: str - filename: Optional[str] - report: Optional[str] - - async def report_on_file_upload(upload: UploadFile) -> FileUploadReport: file = handle_file_upload(upload) diff --git a/backend/src/llm/mistral.py b/backend/src/llm/mistral.py index 324dcdf9..93d8143d 100644 --- a/backend/src/llm/mistral.py +++ b/backend/src/llm/mistral.py @@ -1,9 +1,9 @@ -from typing import Coroutine +from typing import Coroutine, Optional from mistralai import Mistral as MistralApi, UserMessage, SystemMessage import logging from src.utils import Config -from .llm import LLM +from .llm import LLM, LLMFileFromPath, LLMFileFromBytes logger = logging.getLogger(__name__) config = Config() @@ -35,5 +35,12 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=Fa logger.debug('{0} response : "{1}"'.format(model, content)) return content - def chat_with_file(self, model: str, system_prompt: str, user_prompt: str, file_paths: list[str]) -> Coroutine: + def chat_with_file( + self, + model: str, + system_prompt: str, + user_prompt: str, + files_by_path: Optional[list[LLMFileFromPath]] = None, + files_by_stream: Optional[list[LLMFileFromBytes]] = None + ) -> Coroutine: raise Exception("Mistral does not support chat_with_file") diff --git a/backend/src/llm/openai.py b/backend/src/llm/openai.py index 435eaa20..5370c15d 100644 --- a/backend/src/llm/openai.py +++ b/backend/src/llm/openai.py @@ -4,7 +4,7 @@ from src.utils import Config from src.llm import LLM, LLMFileFromPath, LLMFileFromBytes from openai import NOT_GIVEN, AsyncOpenAI -from openai.types.beta.threads import Text +from openai.types.beta.threads import Text, TextContentBlock logger = logging.getLogger(__name__) config = Config() @@ -86,10 +86,14 @@ async def chat_with_file( messages = await client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id) - message = messages.data[0].content[0].text + if isinstance(messages.data[0].content[0], TextContentBlock): + message = remove_citations(messages.data[0].content[0].text) + else: + message = messages.data[0].content[0].to_json() logger.info(f"OpenAI response: {message}") - return remove_citations(message) + return message + async def __upload_files( self, diff --git a/backend/src/session/file_uploads.py b/backend/src/session/file_uploads.py index 9e6dcb4d..e74d965a 100644 --- a/backend/src/session/file_uploads.py +++ b/backend/src/session/file_uploads.py @@ -1,5 +1,6 @@ import json -from typing import TypedDict +from dataclasses import dataclass +from typing import TypedDict, Optional import logging import redis @@ -27,14 +28,18 @@ class FileUploadMeta(TypedDict): class FileUpload(TypedDict): uploadId: str content: str - filename: str | None - contentType: str | None - size: int | None + filename: str + contentType: Optional[str] + size: Optional[int] + -class FileUploadReport(TypedDict): +@dataclass +class FileUploadReport: id: str - filename: str | None - report: str | None + answer: str + filename: Optional[str] + report: Optional[str] + def get_session_file_uploads_meta() -> list[FileUploadMeta] | None: return get_session(UPLOADS_META_SESSION_KEY, []) @@ -52,7 +57,7 @@ def get_session_file_upload(upload_id) -> FileUpload | None: return _get_key(UPLOADS_KEY_PREFIX + upload_id) -def update_session_file_uploads(file_upload:FileUpload): +def update_session_file_uploads(file_upload: FileUpload): file_uploads_meta_session = get_session(UPLOADS_META_SESSION_KEY, []) if not file_uploads_meta_session: # initialise the session object @@ -80,7 +85,7 @@ def clear_session_file_uploads(): set_session(UPLOADS_META_SESSION_KEY, []) -def store_report(report:FileUploadReport): +def store_report(report: FileUploadReport): redis_client.set(REPORT_KEY_PREFIX + report["id"], json.dumps(report)) diff --git a/backend/tests/api/app_test.py b/backend/tests/api/app_test.py index 68420c3c..2fb72e74 100644 --- a/backend/tests/api/app_test.py +++ b/backend/tests/api/app_test.py @@ -49,6 +49,7 @@ def test_chat_response_failure(mocker): assert response.status_code == 500 assert response.json() == chat_fail_response + def test_chat_delete(mocker): mock_reset_session = mocker.patch("src.api.app.reset_session") mock_clear_files = mocker.patch("src.api.app.clear_session_file_uploads") @@ -60,6 +61,7 @@ def test_chat_delete(mocker): assert response.status_code == 204 + def test_chat_message_success(mocker): message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning", dataset="dataset") mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=message) @@ -70,6 +72,7 @@ def test_chat_message_success(mocker): assert response.status_code == 200 assert response.json() == message + def test_chat_message_not_found(mocker): mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=None) @@ -78,15 +81,17 @@ def test_chat_message_not_found(mocker): mock_get_chat_message.assert_called_with("123") assert response.status_code == 404 + def test_report_response_success(mocker): - mock_reponse = FileUploadReport(filename="filename", id="1", report="some report md") - mock_report = mocker.patch("src.api.app.report_on_file_upload", return_value=mock_reponse) + mock_response = FileUploadReport(filename="filename", id="1", report="some report md", answer="chat message") + mock_report = mocker.patch("src.api.app.report_on_file_upload", return_value=mock_response) response = client.post("/report", files={"file": ("filename", "test data".encode("utf-8"), "text/plain")}) mock_report.assert_called_once() assert response.status_code == 200 - assert response.json() == {'filename': 'filename', 'id': '1', 'report': 'some report md'} + assert response.json() == {'filename': 'filename', 'id': '1', 'report': 'some report md', 'answer': 'chat message'} + @pytest.mark.asyncio async def test_lifespan_populates_db(mocker) -> None: @@ -95,8 +100,9 @@ async def test_lifespan_populates_db(mocker) -> None: with client: mock_dataset_upload.assert_called_once_with() + def test_get_report_success(mocker): - report = FileUploadReport(id="12", filename="test.pdf", report="test report") + report = FileUploadReport(id="12", filename="test.pdf", report="test report", answer='chat message') mock_get_report = mocker.patch("src.api.app.get_report", return_value=report) response = client.get("/report/12") @@ -106,6 +112,7 @@ def test_get_report_success(mocker): assert response.headers.get('Content-Disposition') == 'attachment; filename="report.md"' assert response.headers.get('Content-Type') == 'text/markdown; charset=utf-8' + def test_get_report_not_found(mocker): mock_get_report = mocker.patch("src.api.app.get_report", return_value=None)