Skip to content

Commit

Permalink
fix tests and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
IMladjenovic committed Dec 20, 2024
1 parent 688533e commit 195f990
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 32 deletions.
2 changes: 1 addition & 1 deletion backend/src/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
import json
import logging
from typing import List, Type, TypeVar, Optional
from typing import List, Type, TypeVar
from src.llm import LLM, get_llm
from src.utils.log_publisher import LogPrefix, publish_log_info

Expand Down
2 changes: 1 addition & 1 deletion backend/src/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,6 @@ async def delete_all_files(self):
logger.info(f"Open AI: deleting files {files}")
delete_tasks = [client.files.delete(file_id=file["file_id"]) for file in files]
await asyncio.gather(*delete_tasks)
logger.info(f"Open AI: Files deleted")
logger.info("Open AI: Files deleted")
except OpenAIError:
logger.info("OpenAI not configured")
3 changes: 2 additions & 1 deletion backend/src/session/redis_session_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Optional
from uuid import uuid4
import redis
from src.utils import test_redis_connection
Expand Down Expand Up @@ -53,7 +54,7 @@ def ignore_request(request: Request) -> bool:
return request.url.path == '/health' or request.method == 'OPTIONS'


def get_session(key: str, default: list = None):
def get_session(key: str, default: Optional[list] = None):
if not default:
default = []
request: Request = request_context.get()
Expand Down
33 changes: 9 additions & 24 deletions backend/tests/agents/generalist_agent_test.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,20 @@
import pytest
from unittest.mock import patch, AsyncMock
import json
from src.agents.generalist_agent import GeneralistAgent

from src.llm.factory import get_llm
from src.agents.generalist_agent import GeneralistAgent

@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent(
mock_answer_user_question,
):
mock_answer_user_question.return_value = json.dumps(
{"status": "success", "response": json.dumps({"is_valid": True, "answer": "Example summary."})}
)
generalist_agent = GeneralistAgent("llm", "mock_model")

result = await generalist_agent.invoke("example query")
expected_response = {"content": "Example summary.", "ignore_validation": "false"}
assert json.loads(result) == expected_response
mock_model = "mockmodel"
mock_llm = get_llm("mockllm")


@pytest.mark.asyncio
@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock)
async def test_generalist_agent_reponse_format_error(
mock_answer_user_question,
):
mock_answer_user_question.return_value = json.dumps(
{"status": "success", "response": json.dumps({"is_valid": True, "answer_wrong_format": "Example summary."})}
)
generalist_agent = GeneralistAgent("llm", "mock_model")
async def test_generalist_agent(mocker):
mock_llm.chat = mocker.AsyncMock(return_value="Example summary.")

result = await generalist_agent.invoke("example query")
agent = GeneralistAgent(llm_name="mockllm", model=mock_model)

expected_response = {"content": "Error in answer format.", "ignore_validation": "false"}
result = await agent.invoke("example query")
expected_response = {"content": "Example summary.", "ignore_validation": "false"}
assert json.loads(result) == expected_response
2 changes: 1 addition & 1 deletion backend/tests/agents/materiality_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ async def test_invoke_calls_llm(mocker):
mock_llm.chat = mocker.AsyncMock(return_value=json.dumps(mock_selected_files))
mock_llm.chat_with_file = mocker.AsyncMock(return_value=json.dumps(mock_materiality_topics))

response = await agent.list_material_topics("AstraZeneca")
response = await agent.list_material_topics_for_company("AstraZeneca")

assert response == mock_materiality_topics["material_topics"]
4 changes: 2 additions & 2 deletions backend/tests/directors/report_director_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def test_create_report_from_file(mocker):

# Mock materiality agent
mock_materiality_agent = mocker.AsyncMock()
mock_materiality_agent.list_material_topics.return_value = mock_topics
mock_materiality_agent.list_material_topics_for_company.return_value = mock_topics
mocker.patch("src.directors.report_director.get_materiality_agent", return_value=mock_materiality_agent)

mock_store_report = mocker.patch("src.directors.report_director.store_report", return_value=file_upload)
Expand All @@ -46,7 +46,7 @@ async def test_create_report_from_file(mocker):

mock_store_report.assert_called_once_with(expected_response)

mock_materiality_agent.list_material_topics.assert_called_once_with("CompanyABC")
mock_materiality_agent.list_material_topics_for_company.assert_called_once_with("CompanyABC")

assert response == expected_response

Expand Down
4 changes: 2 additions & 2 deletions backend/tests/session/test_redis_session_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ def test_reset_session(mocker, mock_request_context):
assert get_session("key2") == "value2"

reset_session()
assert get_session("key1", None) is None
assert get_session("key2", None) is None
assert get_session("key1") == []
assert get_session("key2") == []

0 comments on commit 195f990

Please sign in to comment.