Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: mock & integration tests for KG/Files SDK functionality #495

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 263 additions & 1 deletion tests/backend/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,33 @@
```
"""

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from writer.ai import Conversation, WriterAIManager, complete, init, stream_complete
from writer.ai import (
Conversation,
File,
FileDeleteResponse,
Graph,
GraphDeleteResponse,
GraphRemoveFileFromGraphResponse,
SDKFile,
SDKGraph,
WriterAIManager,
complete,
create_graph,
delete_file,
delete_graph,
init,
list_files,
list_graphs,
retrieve_file,
retrieve_graph,
stream_complete,
upload_file,
)
from writerai import Writer
from writerai._streaming import Stream
from writerai.types import Chat, ChatStreamingData, Completion, StreamingData
Expand Down Expand Up @@ -116,6 +138,73 @@ def fake_response_content():
yield streaming_client


@pytest.fixture
def sdk_graph_mock():
return SDKGraph(
id="test_graph_id",
created_at=datetime.now(),
file_status={"completed": 0, "failed": 0, "in_progress": 0, "total": 0},
name="test_graph",
description="A test graph"
)


@pytest.fixture
def sdk_file_mock():
return SDKFile(
id="test_file_id",
created_at=datetime.now(),
graph_ids=["test_graph_id"],
name="test_file"
)


@pytest.fixture
def mock_graphs_accessor(sdk_file_mock, sdk_graph_mock):
with patch('writer.ai.Graph._retrieve_graphs_accessor') as mock_acquire_client:
mock_accessor = MagicMock()
mock_graph = Graph(sdk_graph_mock)
mock_file = File(sdk_file_mock)
mock_accessor.create.return_value = mock_graph
mock_accessor.add_file_to_graph.return_value = mock_file
mock_accessor.retrieve.return_value = mock_graph
mock_accessor.list.return_value = [mock_graph]
mock_accessor.delete.return_value = GraphDeleteResponse(id="test_file_id", deleted=True)
mock_accessor.remove_file_from_graph.return_value = GraphRemoveFileFromGraphResponse(id="test_file_id", deleted=True)

mock_acquire_client.return_value = mock_accessor
yield mock_accessor


@pytest.fixture
def mock_files_accessor(sdk_file_mock):
with patch('writer.ai.File._retrieve_files_accessor') as mock_acquire_client:
mock_accessor = MagicMock()
mock_file = File(sdk_file_mock)
mock_accessor.retrieve.return_value = mock_file
mock_accessor.list.return_value = [mock_file]
mock_accessor.upload.return_value = mock_file
mock_accessor.delete.return_value = FileDeleteResponse(id="test_delete", deleted=True)
mock_acquire_client.return_value = mock_accessor
yield mock_accessor


@pytest.fixture
def created_graphs():
graphs = []
yield graphs
for graph in graphs:
delete_graph(graph_id_or_graph=graph.id)


@pytest.fixture
def created_files():
files = []
yield files
for file in files:
delete_file(file_id_or_file=file.id)


class FakeAppProcessForAIManager:
def __init__(self, token):
self.ai_manager = WriterAIManager(token=token)
Expand Down Expand Up @@ -295,6 +384,92 @@ def test_init_writer_ai_manager(emulate_app_process):
assert manager.client.api_key == "fake_token"


def test_create_graph(mock_graphs_accessor):
graph = create_graph(name="test_graph", description="A test graph")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert graph.id == "test_graph_id"
assert graph.name == "test_graph"


def test_retrieve_graph(mock_graphs_accessor):
graph = retrieve_graph(graph_id="test_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert graph.id == "test_graph_id"
assert graph.name == "test_graph"


def test_list_graphs(mock_graphs_accessor):
graphs = list_graphs()

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert len(graphs) == 1
assert graphs[0].id == "test_graph_id"
assert graphs[0].name == "test_graph"


def test_delete_graph(mock_graphs_accessor):
response = delete_graph(graph_id_or_graph="test_graph_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert response.deleted is True


def test_add_file_to_graph(mock_graphs_accessor, mock_files_accessor):
file = retrieve_file(file_id="test_file_id")
graph = retrieve_graph(graph_id="test_graph_id")
added_file = graph.add_file(file_id_or_file=file)

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert added_file.id == "test_file_id"
assert added_file.name == "test_file"
# Graph update should also trigger addition of its ID to stale_ids set
assert "test_graph_id" in Graph.stale_ids


def test_retrieve_file(mock_files_accessor):
file = retrieve_file(file_id="test_file_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert file.id == "test_file_id"
assert file.name == "test_file"


def test_list_files(mock_files_accessor):
files = list_files()

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert len(files) == 1
assert files[0].id == "test_file_id"
assert files[0].name == "test_file"


def test_upload_file(mock_files_accessor):
data = b"file_content"
file = upload_file(data=data, type="text/plain", name="uploaded_file")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert file.id == "test_file_id"
assert file.name == "test_file"


def test_delete_file(mock_files_accessor):
response = delete_file(file_id_or_file="test_file_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert response.deleted is True


@explicit
def test_explicit_conversation_complete(emulate_app_process):
conversation = Conversation()
Expand Down Expand Up @@ -341,3 +516,90 @@ async def test_explicit_stream_complete(emulate_app_process):
assert isinstance(full_response, str)
assert len(full_response) > 0
assert "exercise" in full_response.lower()


@explicit
def test_explicit_create_graph(emulate_app_process, created_graphs):
graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(graph)
assert graph.id is not None
assert graph.name == "integration_test_graph"


@explicit
def test_explicit_retrieve_graph(emulate_app_process, created_graphs):
created_graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(created_graph)
graph = retrieve_graph(graph_id=created_graph.id)

assert graph.id == created_graph.id
assert graph.name == created_graph.name


@explicit
def test_explicit_list_graphs(emulate_app_process, created_graphs):
# Create a graph to ensure there's at least one graph in the list
graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(graph)

graphs = list_graphs()

assert len(graphs) > 0
assert any(g.id == graph.id for g in graphs)


@explicit
def test_explicit_delete_graph(emulate_app_process, created_graphs):
created_graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(created_graph)
response = delete_graph(graph_id_or_graph=created_graph.id)

# Ensure the graph is removed from created_graphs
created_graphs.remove(created_graph)

assert response.deleted is True


@explicit
def test_explicit_upload_file(emulate_app_process, created_files):
data = b"file_content"
file = upload_file(data=data, type="text/plain", name="integration_uploaded_file")
created_files.append(file)

assert file.id is not None
assert file.name == "integration_uploaded_file.txt"


@explicit
def test_explicit_retrieve_file(emulate_app_process, created_files):
uploaded_file = upload_file(data=b"file_content", type="text/plain", name="integration_uploaded_file")
created_files.append(uploaded_file)
file = retrieve_file(file_id=uploaded_file.id)

assert file.id == uploaded_file.id
assert file.name == uploaded_file.name


@explicit
def test_explicit_list_files(emulate_app_process, created_files):
# Upload a file to ensure there's at least one file in the list
data = b"file_content"
file = upload_file(data=data, type="text/plain", name="integration_uploaded_file")
created_files.append(file)

files = list_files()

assert len(files) > 0
assert any(f.id == file.id for f in files)


@explicit
def test_explicit_delete_file(emulate_app_process, created_files):
uploaded_file = upload_file(data=b"file_content", type="text/plain", name="integration_uploaded_file")
created_files.append(uploaded_file)
response = delete_file(file_id_or_file=uploaded_file.id)

# Ensure the file is removed from created_files
created_files.remove(uploaded_file)

assert response.deleted is True
Loading