From 195f9902c61135deb5d97c01120461cada8675ec Mon Sep 17 00:00:00 2001 From: "Ivan Mladjenovic (He/Him)" Date: Fri, 20 Dec 2024 12:03:42 +0000 Subject: [PATCH] fix tests and linting --- backend/src/agents/agent.py | 2 +- backend/src/llm/openai.py | 2 +- .../src/session/redis_session_middleware.py | 3 +- backend/tests/agents/generalist_agent_test.py | 33 +++++-------------- .../tests/agents/materiality_agent_test.py | 2 +- .../tests/directors/report_director_test.py | 4 +-- .../session/test_redis_session_middleware.py | 4 +-- 7 files changed, 18 insertions(+), 32 deletions(-) diff --git a/backend/src/agents/agent.py b/backend/src/agents/agent.py index 88cb693b..ad29dbcf 100644 --- a/backend/src/agents/agent.py +++ b/backend/src/agents/agent.py @@ -1,7 +1,7 @@ from abc import ABC import json import logging -from typing import List, Type, TypeVar, Optional +from typing import List, Type, TypeVar from src.llm import LLM, get_llm from src.utils.log_publisher import LogPrefix, publish_log_info diff --git a/backend/src/llm/openai.py b/backend/src/llm/openai.py index fac1134d..d7ca0b76 100644 --- a/backend/src/llm/openai.py +++ b/backend/src/llm/openai.py @@ -123,6 +123,6 @@ async def delete_all_files(self): logger.info(f"Open AI: deleting files {files}") delete_tasks = [client.files.delete(file_id=file["file_id"]) for file in files] await asyncio.gather(*delete_tasks) - logger.info(f"Open AI: Files deleted") + logger.info("Open AI: Files deleted") except OpenAIError: logger.info("OpenAI not configured") diff --git a/backend/src/session/redis_session_middleware.py b/backend/src/session/redis_session_middleware.py index 2dab1cb1..207eb3f5 100644 --- a/backend/src/session/redis_session_middleware.py +++ b/backend/src/session/redis_session_middleware.py @@ -1,4 +1,5 @@ import json +from typing import Optional from uuid import uuid4 import redis from src.utils import test_redis_connection @@ -53,7 +54,7 @@ def ignore_request(request: Request) -> bool: return request.url.path == '/health' or request.method == 'OPTIONS' -def get_session(key: str, default: list = None): +def get_session(key: str, default: Optional[list] = None): if not default: default = [] request: Request = request_context.get() diff --git a/backend/tests/agents/generalist_agent_test.py b/backend/tests/agents/generalist_agent_test.py index 2d67890f..7d644f3a 100644 --- a/backend/tests/agents/generalist_agent_test.py +++ b/backend/tests/agents/generalist_agent_test.py @@ -1,35 +1,20 @@ import pytest -from unittest.mock import patch, AsyncMock import json -from src.agents.generalist_agent import GeneralistAgent +from src.llm.factory import get_llm +from src.agents.generalist_agent import GeneralistAgent -@pytest.mark.asyncio -@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock) -async def test_generalist_agent( - mock_answer_user_question, -): - mock_answer_user_question.return_value = json.dumps( - {"status": "success", "response": json.dumps({"is_valid": True, "answer": "Example summary."})} - ) - generalist_agent = GeneralistAgent("llm", "mock_model") - result = await generalist_agent.invoke("example query") - expected_response = {"content": "Example summary.", "ignore_validation": "false"} - assert json.loads(result) == expected_response +mock_model = "mockmodel" +mock_llm = get_llm("mockllm") @pytest.mark.asyncio -@patch("src.agents.generalist_agent.answer_user_question", new_callable=AsyncMock) -async def test_generalist_agent_reponse_format_error( - mock_answer_user_question, -): - mock_answer_user_question.return_value = json.dumps( - {"status": "success", "response": json.dumps({"is_valid": True, "answer_wrong_format": "Example summary."})} - ) - generalist_agent = GeneralistAgent("llm", "mock_model") +async def test_generalist_agent(mocker): + mock_llm.chat = mocker.AsyncMock(return_value="Example summary.") - result = await generalist_agent.invoke("example query") + agent = GeneralistAgent(llm_name="mockllm", model=mock_model) - expected_response = {"content": "Error in answer format.", "ignore_validation": "false"} + result = await agent.invoke("example query") + expected_response = {"content": "Example summary.", "ignore_validation": "false"} assert json.loads(result) == expected_response diff --git a/backend/tests/agents/materiality_agent_test.py b/backend/tests/agents/materiality_agent_test.py index 3c5bbe1b..9839772f 100644 --- a/backend/tests/agents/materiality_agent_test.py +++ b/backend/tests/agents/materiality_agent_test.py @@ -28,6 +28,6 @@ async def test_invoke_calls_llm(mocker): mock_llm.chat = mocker.AsyncMock(return_value=json.dumps(mock_selected_files)) mock_llm.chat_with_file = mocker.AsyncMock(return_value=json.dumps(mock_materiality_topics)) - response = await agent.list_material_topics("AstraZeneca") + response = await agent.list_material_topics_for_company("AstraZeneca") assert response == mock_materiality_topics["material_topics"] diff --git a/backend/tests/directors/report_director_test.py b/backend/tests/directors/report_director_test.py index b95bd450..d684eb0a 100644 --- a/backend/tests/directors/report_director_test.py +++ b/backend/tests/directors/report_director_test.py @@ -32,7 +32,7 @@ async def test_create_report_from_file(mocker): # Mock materiality agent mock_materiality_agent = mocker.AsyncMock() - mock_materiality_agent.list_material_topics.return_value = mock_topics + mock_materiality_agent.list_material_topics_for_company.return_value = mock_topics mocker.patch("src.directors.report_director.get_materiality_agent", return_value=mock_materiality_agent) mock_store_report = mocker.patch("src.directors.report_director.store_report", return_value=file_upload) @@ -46,7 +46,7 @@ async def test_create_report_from_file(mocker): mock_store_report.assert_called_once_with(expected_response) - mock_materiality_agent.list_material_topics.assert_called_once_with("CompanyABC") + mock_materiality_agent.list_material_topics_for_company.assert_called_once_with("CompanyABC") assert response == expected_response diff --git a/backend/tests/session/test_redis_session_middleware.py b/backend/tests/session/test_redis_session_middleware.py index 8b3422d9..46183772 100644 --- a/backend/tests/session/test_redis_session_middleware.py +++ b/backend/tests/session/test_redis_session_middleware.py @@ -74,5 +74,5 @@ def test_reset_session(mocker, mock_request_context): assert get_session("key2") == "value2" reset_session() - assert get_session("key1", None) is None - assert get_session("key2", None) is None + assert get_session("key1") == [] + assert get_session("key2") == []