diff --git a/backend/conftest.py b/backend/conftest.py
index 64b912e1..d9ba2607 100644
--- a/backend/conftest.py
+++ b/backend/conftest.py
@@ -1,28 +1,9 @@
 # conftest.py
 import pytest
-import logging
 import os
-from unittest.mock import patch
-
-from openai import AsyncOpenAI
-
-logger = logging.getLogger(__name__)
 
 
 @pytest.hookimpl(tryfirst=True)
 def pytest_configure(config):
     # Set an environment variable to indicate pytest is running
     os.environ["PYTEST_RUNNING"] = "1"
-    os.environ["OPENAI_API_KEY"] = "fake_key"
-
-
-class MockOpenAI(AsyncOpenAI):
-    pass
-
-
-@pytest.fixture(autouse=True)
-@patch("src.llm.OpenAI.client")  # Ensure this matches the import path in your module
-def mock_async_openai():
-    with patch("openai.AsyncOpenAI", MockOpenAI) as mock_async_openai:
-        yield mock_async_openai
-        print("AsyncOpenAI patched")
diff --git a/backend/src/llm/openai.py b/backend/src/llm/openai.py
index 461ea533..435eaa20 100644
--- a/backend/src/llm/openai.py
+++ b/backend/src/llm/openai.py
@@ -18,7 +18,6 @@ def remove_citations(message: Text):
 
 
 class OpenAI(LLM):
-    client = AsyncOpenAI(api_key=config.openai_key)
 
     async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
         logger.debug(
@@ -27,7 +26,8 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=Fa
             )
         )
         try:
-            response = await self.client.chat.completions.create(
+            client = AsyncOpenAI(api_key=config.openai_key)
+            response = await client.chat.completions.create(
                 model=model,
                 messages=[
                     {"role": "system", "content": system_prompt},
@@ -57,16 +57,17 @@ async def chat_with_file(
         files_by_path: Optional[list[LLMFileFromPath]] = None,
         files_by_stream: Optional[list[LLMFileFromBytes]] = None
     ) -> str:
+        client = AsyncOpenAI(api_key=config.openai_key)
         file_ids = await self.__upload_files(files_by_path, files_by_stream)
 
-        file_assistant = await self.client.beta.assistants.create(
+        file_assistant = await client.beta.assistants.create(
             name="ESG Analyst",
             instructions=system_prompt,
             model=model,
             tools=[{"type": "file_search"}],
         )
 
-        thread = await self.client.beta.threads.create(
+        thread = await client.beta.threads.create(
             messages=[
                 {
                     "role": "user",
@@ -79,11 +80,11 @@ async def chat_with_file(
             ]
         )
 
-        run = await self.client.beta.threads.runs.create_and_poll(
+        run = await client.beta.threads.runs.create_and_poll(
             thread_id=thread.id, assistant_id=file_assistant.id
         )
 
-        messages = await self.client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
+        messages = await client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
 
         message = messages.data[0].content[0].text
 
@@ -95,6 +96,7 @@ async def __upload_files(
         files_by_path: Optional[list[LLMFileFromPath]],
         files_by_stream: Optional[list[LLMFileFromBytes]]
     ) -> list[str]:
+        client = AsyncOpenAI(api_key=config.openai_key)
         if not files_by_path:
             files_by_path = []
         if not files_by_stream:
@@ -103,7 +105,7 @@ async def __upload_files(
         file_ids = []
         for file in files_by_stream + files_by_path:
             logger.info(f"Uploading file '{file.file_name}' to OpenAI")
-            file = await self.client.files.create(
+            file = await client.files.create(
                 file=file.file_path if isinstance(file, LLMFileFromPath) else file.file_stream,
                 purpose="assistants"
             )
diff --git a/backend/tests/llm/openai_test.py b/backend/tests/llm/openai_test.py
index cc10276a..1d230923 100644
--- a/backend/tests/llm/openai_test.py
+++ b/backend/tests/llm/openai_test.py
@@ -49,16 +49,18 @@ class MockListResponse:
 
 
 @pytest.mark.asyncio
-@patch("src.llm.OpenAI.client")  # Ensure this matches the import path in your module
-async def test_chat_with_file_removes_citations(mock_client):
-    mock_client.files.create = AsyncMock(return_value=MockResponse(id="file-id"))
-    mock_client.beta.assistants.create = AsyncMock(return_value=MockResponse(id="assistant-id"))
-    mock_client.beta.threads.create = AsyncMock(return_value=MockResponse(id="thread-id"))
-    mock_client.beta.threads.runs.create_and_poll = AsyncMock(return_value=MockResponse(id="run-id"))
-    mock_client.beta.threads.messages.list = AsyncMock(return_value=MockListResponse)
+@patch("openai.AsyncOpenAI")  # Ensure this matches the import path in your module
+async def test_chat_with_file_removes_citations(mock_async_openai):
+    mock_instance = mock_async_openai.return_value
 
-    llm = OpenAI()
-    response = await llm.chat_with_file(
+    mock_instance.files.create = AsyncMock(return_value=MockResponse(id="file-id"))
+    mock_instance.beta.assistants.create = AsyncMock(return_value=MockResponse(id="assistant-id"))
+    mock_instance.beta.threads.create = AsyncMock(return_value=MockResponse(id="thread-id"))
+    mock_instance.beta.threads.runs.create_and_poll = AsyncMock(return_value=MockResponse(id="run-id"))
+    mock_instance.beta.threads.messages.list = AsyncMock(return_value=MockListResponse)
+
+    client = OpenAI()
+    response = await client.chat_with_file(
         model="",
         user_prompt="",
         system_prompt="",
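
Note on the pattern in this patch: the old class attribute client = AsyncOpenAI(api_key=config.openai_key) was evaluated once at import time, so a test patch could never intercept it. Constructing the client inside each method defers the constructor lookup to call time, which is what allows @patch("openai.AsyncOpenAI") to substitute a mock. The inline comment "Ensure this matches the import path in your module" is the key caveat: if the production module does import openai and calls openai.AsyncOpenAI(...), patching "openai.AsyncOpenAI" works; if it does from openai import AsyncOpenAI, the patch target must be the importing module's own name (e.g. "src.llm.openai.AsyncOpenAI"). Below is a minimal, self-contained sketch of the pattern; the module and test names are hypothetical illustrations, not part of this PR.

# sketch.py -- hypothetical module mirroring the per-call client construction
import openai  # resolving openai.AsyncOpenAI at call time keeps the name patchable


class OpenAI:
    async def chat(self, model: str, system_prompt: str, user_prompt: str) -> str:
        # Construct the client per call: a test that patches openai.AsyncOpenAI
        # has already swapped in the mock before this line runs.
        client = openai.AsyncOpenAI(api_key="sk-fake")  # placeholder key for the sketch
        response = await client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return response.choices[0].message.content


# sketch_test.py -- hypothetical companion test (requires pytest-asyncio)
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from sketch import OpenAI


@pytest.mark.asyncio
@patch("openai.AsyncOpenAI")  # matches the "import openai" style used in sketch.py
async def test_chat_returns_mocked_message(mock_async_openai):
    # The instance the production code receives is the mock's return_value.
    mock_instance = mock_async_openai.return_value
    completion = MagicMock()
    completion.choices[0].message.content = "mocked reply"
    mock_instance.chat.completions.create = AsyncMock(return_value=completion)

    response = await OpenAI().chat(model="", system_prompt="", user_prompt="")

    assert response == "mocked reply"

Had sketch.py used from openai import AsyncOpenAI instead, the same test would need @patch("sketch.AsyncOpenAI"), because the name bound at import time lives in the importing module's namespace, not in openai's.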