Skip to content

Commit

Permalink
further type check fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
IMladjenovic committed Dec 13, 2024
1 parent 00474ed commit c55973b
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 31 deletions.
4 changes: 2 additions & 2 deletions backend/src/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
import json
import logging
from typing import List, Type, Union, TypeVar
from typing import List, Type, Union, TypeVar, Optional

Check failure on line 4 in backend/src/agents/agent.py

View workflow job for this annotation

GitHub Actions / Linting Backend

Ruff (F401)

backend/src/agents/agent.py:4:32: F401 `typing.Union` imported but unused
from src.llm import LLM, get_llm
from src.utils.log_publisher import LogPrefix, publish_log_info

Expand Down Expand Up @@ -67,7 +67,7 @@ async def invoke(self, utterance: str) -> str:
T = TypeVar('T', bound=Agent)


def agent(name: str, description: str, tools: List[Tool] = None):
def agent(name: str, description: str, tools: Optional[List[Tool]] = None):
def decorator(agent: Type[T]) -> Type[T]:
agent.name = name
agent.description = description
Expand Down
10 changes: 0 additions & 10 deletions backend/src/directors/report_director.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
from typing import TypedDict, Optional
from fastapi import UploadFile
from dataclasses import dataclass

from src.session.file_uploads import FileUploadReport, store_report
from src.utils.file_utils import handle_file_upload
from src.agents import get_report_agent, get_materiality_agent


@dataclass
class FileUploadReport(TypedDict):
id: str
answer: str
filename: Optional[str]
report: Optional[str]


async def report_on_file_upload(upload: UploadFile) -> FileUploadReport:

file = handle_file_upload(upload)
Expand Down
13 changes: 10 additions & 3 deletions backend/src/llm/mistral.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Coroutine
from typing import Coroutine, Optional

from mistralai import Mistral as MistralApi, UserMessage, SystemMessage
import logging
from src.utils import Config
from .llm import LLM
from .llm import LLM, LLMFileFromPath, LLMFileFromBytes

logger = logging.getLogger(__name__)
config = Config()
Expand Down Expand Up @@ -35,5 +35,12 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=Fa
logger.debug('{0} response : "{1}"'.format(model, content))
return content

def chat_with_file(self, model: str, system_prompt: str, user_prompt: str, file_paths: list[str]) -> Coroutine:
def chat_with_file(
self,
model: str,
system_prompt: str,
user_prompt: str,
files_by_path: Optional[list[LLMFileFromPath]] = None,
files_by_stream: Optional[list[LLMFileFromBytes]] = None
) -> Coroutine:
raise Exception("Mistral does not support chat_with_file")
10 changes: 7 additions & 3 deletions backend/src/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from src.utils import Config
from src.llm import LLM, LLMFileFromPath, LLMFileFromBytes
from openai import NOT_GIVEN, AsyncOpenAI
from openai.types.beta.threads import Text
from openai.types.beta.threads import Text, TextContentBlock

logger = logging.getLogger(__name__)
config = Config()
Expand Down Expand Up @@ -86,10 +86,14 @@ async def chat_with_file(

messages = await client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)

message = messages.data[0].content[0].text
if isinstance(messages.data[0].content[0], TextContentBlock):
message = remove_citations(messages.data[0].content[0].text)
else:
message = messages.data[0].content[0].to_json()

logger.info(f"OpenAI response: {message}")
return remove_citations(message)
return message


async def __upload_files(
self,
Expand Down
23 changes: 14 additions & 9 deletions backend/src/session/file_uploads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import TypedDict
from dataclasses import dataclass
from typing import TypedDict, Optional
import logging
import redis

Expand Down Expand Up @@ -27,14 +28,18 @@ class FileUploadMeta(TypedDict):
class FileUpload(TypedDict):
uploadId: str
content: str
filename: str | None
contentType: str | None
size: int | None
filename: str
contentType: Optional[str]
size: Optional[int]


class FileUploadReport(TypedDict):
@dataclass
class FileUploadReport:
id: str
filename: str | None
report: str | None
answer: str
filename: Optional[str]
report: Optional[str]


def get_session_file_uploads_meta() -> list[FileUploadMeta] | None:
return get_session(UPLOADS_META_SESSION_KEY, [])
Expand All @@ -52,7 +57,7 @@ def get_session_file_upload(upload_id) -> FileUpload | None:
return _get_key(UPLOADS_KEY_PREFIX + upload_id)


def update_session_file_uploads(file_upload:FileUpload):
def update_session_file_uploads(file_upload: FileUpload):
file_uploads_meta_session = get_session(UPLOADS_META_SESSION_KEY, [])
if not file_uploads_meta_session:
# initialise the session object
Expand Down Expand Up @@ -80,7 +85,7 @@ def clear_session_file_uploads():
set_session(UPLOADS_META_SESSION_KEY, [])


def store_report(report:FileUploadReport):
def store_report(report: FileUploadReport):
redis_client.set(REPORT_KEY_PREFIX + report["id"], json.dumps(report))

Check failure on line 89 in backend/src/session/file_uploads.py

View workflow job for this annotation

GitHub Actions / Type Checking Backend

"__getitem__" method not defined on type "FileUploadReport" (reportIndexIssue)


Expand Down
15 changes: 11 additions & 4 deletions backend/tests/api/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_chat_response_failure(mocker):
assert response.status_code == 500
assert response.json() == chat_fail_response


def test_chat_delete(mocker):
mock_reset_session = mocker.patch("src.api.app.reset_session")
mock_clear_files = mocker.patch("src.api.app.clear_session_file_uploads")
Expand All @@ -60,6 +61,7 @@ def test_chat_delete(mocker):

assert response.status_code == 204


def test_chat_message_success(mocker):
message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning", dataset="dataset")
mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=message)
Expand All @@ -70,6 +72,7 @@ def test_chat_message_success(mocker):
assert response.status_code == 200
assert response.json() == message


def test_chat_message_not_found(mocker):
mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=None)

Expand All @@ -78,15 +81,17 @@ def test_chat_message_not_found(mocker):
mock_get_chat_message.assert_called_with("123")
assert response.status_code == 404


def test_report_response_success(mocker):
mock_reponse = FileUploadReport(filename="filename", id="1", report="some report md")
mock_report = mocker.patch("src.api.app.report_on_file_upload", return_value=mock_reponse)
mock_response = FileUploadReport(filename="filename", id="1", report="some report md", answer="chat message")
mock_report = mocker.patch("src.api.app.report_on_file_upload", return_value=mock_response)

response = client.post("/report", files={"file": ("filename", "test data".encode("utf-8"), "text/plain")})

mock_report.assert_called_once()
assert response.status_code == 200
assert response.json() == {'filename': 'filename', 'id': '1', 'report': 'some report md'}
assert response.json() == {'filename': 'filename', 'id': '1', 'report': 'some report md', 'answer': 'chat message'}


@pytest.mark.asyncio
async def test_lifespan_populates_db(mocker) -> None:
Expand All @@ -95,8 +100,9 @@ async def test_lifespan_populates_db(mocker) -> None:
with client:
mock_dataset_upload.assert_called_once_with()


def test_get_report_success(mocker):
report = FileUploadReport(id="12", filename="test.pdf", report="test report")
report = FileUploadReport(id="12", filename="test.pdf", report="test report", answer='chat message')
mock_get_report = mocker.patch("src.api.app.get_report", return_value=report)

response = client.get("/report/12")
Expand All @@ -106,6 +112,7 @@ def test_get_report_success(mocker):
assert response.headers.get('Content-Disposition') == 'attachment; filename="report.md"'
assert response.headers.get('Content-Type') == 'text/markdown; charset=utf-8'


def test_get_report_not_found(mocker):
mock_get_report = mocker.patch("src.api.app.get_report", return_value=None)

Expand Down

0 comments on commit c55973b

Please sign in to comment.