Skip to content

Commit

Permalink
fix pytests
Browse files Browse the repository at this point in the history
  • Loading branch information
IMladjenovic committed Dec 12, 2024
1 parent 281c208 commit 568b03d
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
8 changes: 8 additions & 0 deletions backend/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
# conftest.py
import pytest
import os
from unittest.mock import AsyncMock, patch

Check failure on line 4 in backend/conftest.py

View workflow job for this annotation

GitHub Actions / Linting Backend

Ruff (F401)

backend/conftest.py:4:27: F401 `unittest.mock.AsyncMock` imported but unused


@pytest.hookimpl(tryfirst=True)
def pytest_configure(config):
    """Mark the process as running under pytest.

    Runs before other ``pytest_configure`` hooks (``tryfirst``) so the flag
    is visible to any code imported during collection. Application code can
    check ``PYTEST_RUNNING`` to adjust behaviour during tests.
    """
    # Setting via update keeps the exact key/value the app code looks for.
    os.environ.update({"PYTEST_RUNNING": "1"})


@pytest.fixture(autouse=True)
def mock_async_openai():
    """Replace ``AsyncOpenAI`` with a mock for every test automatically.

    Yields the mock class so tests can inspect or configure it. Being
    ``autouse`` prevents any test from accidentally constructing a real
    OpenAI client and hitting the network.
    """
    patcher = patch("src.llm.openai.AsyncOpenAI")
    mocked_client_cls = patcher.start()
    try:
        yield mocked_client_cls
    finally:
        # Explicit stop mirrors the context-manager exit: the patch is
        # always undone, even if the test raised.
        patcher.stop()
3 changes: 1 addition & 2 deletions backend/src/llm/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@


class Mistral(LLM):
def __init__(self):
self.client = MistralApi(api_key=config.mistral_key)
client = MistralApi(api_key=config.mistral_key)

async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
logger.debug("Called llm. Waiting on response model with prompt {0}.".format(str([system_prompt, user_prompt])))
Expand Down
3 changes: 1 addition & 2 deletions backend/src/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ def remove_citations(message: Text):


class OpenAI(LLM):
def __init__(self):
self.client = AsyncOpenAI(api_key=config.openai_key)
client = AsyncOpenAI(api_key=config.openai_key)

async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
logger.debug(
Expand Down
25 changes: 11 additions & 14 deletions backend/tests/llm/openai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
from src.llm.openai import OpenAI


def mock_openai_object(id_value: str) -> AsyncMock:
mock_obj = AsyncMock()
mock_obj.id = id_value
return AsyncMock(return_value=mock_obj)
@dataclass
class MockResponse:
    """Minimal stand-in for an OpenAI SDK response object.

    The code under test only reads the ``id`` attribute of the objects
    returned by ``files.create`` / ``assistants.create`` / ``threads.create``
    / ``runs.create_and_poll``, so a single-field dataclass suffices.
    """

    id: str  # identifier the mocked API call reports back (e.g. "file-id")


@dataclass
Expand Down Expand Up @@ -50,18 +49,16 @@ class MockListResponse:


@pytest.mark.asyncio
@patch("src.llm.openai.AsyncOpenAI")
@patch("src.llm.OpenAI.client") # Ensure this matches the import path in your module
async def test_chat_with_file_removes_citations(mock_client):
mock_instance = mock_client.return_value

mock_instance.files.create = mock_openai_object(id_value="file-id")
mock_instance.beta.assistants.create = mock_openai_object(id_value="assistant-id")
mock_instance.beta.threads.create = mock_openai_object(id_value="thread-id")
mock_instance.beta.threads.runs.create_and_poll = mock_openai_object(id_value="run-id")
mock_instance.beta.threads.messages.list = AsyncMock(return_value=MockListResponse)
mock_client.files.create = AsyncMock(return_value=MockResponse(id="file-id"))
mock_client.beta.assistants.create = AsyncMock(return_value=MockResponse(id="assistant-id"))
mock_client.beta.threads.create = AsyncMock(return_value=MockResponse(id="thread-id"))
mock_client.beta.threads.runs.create_and_poll = AsyncMock(return_value=MockResponse(id="run-id"))
mock_client.beta.threads.messages.list = AsyncMock(return_value=MockListResponse)

client = OpenAI()
response = await client.chat_with_file(
llm = OpenAI()
response = await llm.chat_with_file(
model="",
user_prompt="",
system_prompt="",
Expand Down

0 comments on commit 568b03d

Please sign in to comment.