Skip to content

Commit

Permalink
fake openai key
Browse files — view the repository at this point in the history
  • Loading branch information
IMladjenovic committed Dec 12, 2024
1 parent eebd1e6 commit d5030bd
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 35 deletions.
19 changes: 0 additions & 19 deletions backend/conftest.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,9 @@
# conftest.py
import pytest
import logging
import os
from unittest.mock import patch

from openai import AsyncOpenAI

logger = logging.getLogger(__name__)


@pytest.hookimpl(tryfirst=True)
def pytest_configure(config):
    """Runs before test collection: flag the pytest run and install a dummy OpenAI key."""
    env = os.environ
    env["PYTEST_RUNNING"] = "1"          # lets application code detect it is running under pytest
    env["OPENAI_API_KEY"] = "fake_key"   # ensures no real credential is ever required by tests


class MockOpenAI(AsyncOpenAI):
    """Stand-in used when patching ``openai.AsyncOpenAI`` in tests.

    Adds no behavior of its own — it exists only so patched call sites
    receive a distinct, identifiable subclass type.

    NOTE(review): instantiation still runs ``AsyncOpenAI.__init__``, so an
    api key (the fake one set in ``pytest_configure``) is presumably still
    required at construction time — confirm against the openai SDK.
    """
    pass


@pytest.fixture(autouse=True)
def mock_async_openai():
    """Automatically replace ``openai.AsyncOpenAI`` with ``MockOpenAI`` for every test.

    Yields:
        The patcher's replacement object (``MockOpenAI``), in case a test
        wants to inspect or configure it.
    """
    # BUG FIX: the previous version stacked ``@patch("src.llm.OpenAI.client")``
    # under ``@pytest.fixture``.  ``mock.patch`` as a decorator injects the mock
    # as a positional argument, but this function accepts none, so the fixture
    # raised TypeError whenever it ran; ``patch`` also does not compose
    # correctly around a generator-style fixture.  The context manager below is
    # the correct way to scope the patch to each test.
    with patch("openai.AsyncOpenAI", MockOpenAI) as patched:
        yield patched
    # Patch is reverted automatically on context exit (test teardown).
16 changes: 9 additions & 7 deletions backend/src/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def remove_citations(message: Text):


class OpenAI(LLM):
client = AsyncOpenAI(api_key=config.openai_key)

async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
logger.debug(
Expand All @@ -27,7 +26,8 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=Fa
)
)
try:
response = await self.client.chat.completions.create(
client = AsyncOpenAI(api_key=config.openai_key)
response = await client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
Expand Down Expand Up @@ -57,16 +57,17 @@ async def chat_with_file(
files_by_path: Optional[list[LLMFileFromPath]] = None,
files_by_stream: Optional[list[LLMFileFromBytes]] = None
) -> str:
client = AsyncOpenAI(api_key=config.openai_key)
file_ids = await self.__upload_files(files_by_path, files_by_stream)

file_assistant = await self.client.beta.assistants.create(
file_assistant = await client.beta.assistants.create(
name="ESG Analyst",
instructions=system_prompt,
model=model,
tools=[{"type": "file_search"}],
)

thread = await self.client.beta.threads.create(
thread = await client.beta.threads.create(
messages=[
{
"role": "user",
Expand All @@ -79,11 +80,11 @@ async def chat_with_file(
]
)

run = await self.client.beta.threads.runs.create_and_poll(
run = await client.beta.threads.runs.create_and_poll(
thread_id=thread.id, assistant_id=file_assistant.id
)

messages = await self.client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
messages = await client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)

message = messages.data[0].content[0].text

Expand All @@ -95,6 +96,7 @@ async def __upload_files(
files_by_path: Optional[list[LLMFileFromPath]],
files_by_stream: Optional[list[LLMFileFromBytes]]
) -> list[str]:
client = AsyncOpenAI(api_key=config.openai_key)
if not files_by_path:
files_by_path = []
if not files_by_stream:
Expand All @@ -103,7 +105,7 @@ async def __upload_files(
file_ids = []
for file in files_by_stream + files_by_path:
logger.info(f"Uploading file '{file.file_name}' to OpenAI")
file = await self.client.files.create(
file = await client.files.create(
file=file.file_path if isinstance(file, LLMFileFromPath) else file.file_stream,
purpose="assistants"
)
Expand Down
20 changes: 11 additions & 9 deletions backend/tests/llm/openai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,18 @@ class MockListResponse:


@pytest.mark.asyncio
@patch("src.llm.OpenAI.client") # Ensure this matches the import path in your module
async def test_chat_with_file_removes_citations(mock_client):
mock_client.files.create = AsyncMock(return_value=MockResponse(id="file-id"))
mock_client.beta.assistants.create = AsyncMock(return_value=MockResponse(id="assistant-id"))
mock_client.beta.threads.create = AsyncMock(return_value=MockResponse(id="thread-id"))
mock_client.beta.threads.runs.create_and_poll = AsyncMock(return_value=MockResponse(id="run-id"))
mock_client.beta.threads.messages.list = AsyncMock(return_value=MockListResponse)
@patch("openai.AsyncOpenAI") # Ensure this matches the import path in your module
async def test_chat_with_file_removes_citations(mock_async_openai):
mock_instance = mock_async_openai.return_value

llm = OpenAI()
response = await llm.chat_with_file(
mock_instance.files.create = AsyncMock(return_value=MockResponse(id="file-id"))
mock_instance.beta.assistants.create = AsyncMock(return_value=MockResponse(id="assistant-id"))
mock_instance.beta.threads.create = AsyncMock(return_value=MockResponse(id="thread-id"))
mock_instance.beta.threads.runs.create_and_poll = AsyncMock(return_value=MockResponse(id="run-id"))
mock_instance.beta.threads.messages.list = AsyncMock(return_value=MockListResponse)

client = OpenAI()
response = await client .chat_with_file(
model="",
user_prompt="",
system_prompt="",
Expand Down

0 comments on commit d5030bd

Please sign in to comment.