Skip to content

Commit

Permalink
Merge pull request #495 from mmikita95/tests-graph-files
Browse files Browse the repository at this point in the history
tests: mock & integration tests for KG/Files SDK functionality
  • Loading branch information
ramedina86 authored Jul 23, 2024
2 parents 9c4730b + 122f0aa commit 86f28ed
Showing 1 changed file with 263 additions and 1 deletion.
264 changes: 263 additions & 1 deletion tests/backend/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,33 @@
```
"""

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from writer.ai import Conversation, WriterAIManager, complete, init, stream_complete
from writer.ai import (
Conversation,
File,
FileDeleteResponse,
Graph,
GraphDeleteResponse,
GraphRemoveFileFromGraphResponse,
SDKFile,
SDKGraph,
WriterAIManager,
complete,
create_graph,
delete_file,
delete_graph,
init,
list_files,
list_graphs,
retrieve_file,
retrieve_graph,
stream_complete,
upload_file,
)
from writerai import Writer
from writerai._streaming import Stream
from writerai.types import Chat, ChatStreamingData, Completion, StreamingData
Expand Down Expand Up @@ -116,6 +138,73 @@ def fake_response_content():
yield streaming_client


@pytest.fixture
def sdk_graph_mock():
return SDKGraph(
id="test_graph_id",
created_at=datetime.now(),
file_status={"completed": 0, "failed": 0, "in_progress": 0, "total": 0},
name="test_graph",
description="A test graph"
)


@pytest.fixture
def sdk_file_mock():
return SDKFile(
id="test_file_id",
created_at=datetime.now(),
graph_ids=["test_graph_id"],
name="test_file"
)


@pytest.fixture
def mock_graphs_accessor(sdk_file_mock, sdk_graph_mock):
with patch('writer.ai.Graph._retrieve_graphs_accessor') as mock_acquire_client:
mock_accessor = MagicMock()
mock_graph = Graph(sdk_graph_mock)
mock_file = File(sdk_file_mock)
mock_accessor.create.return_value = mock_graph
mock_accessor.add_file_to_graph.return_value = mock_file
mock_accessor.retrieve.return_value = mock_graph
mock_accessor.list.return_value = [mock_graph]
mock_accessor.delete.return_value = GraphDeleteResponse(id="test_file_id", deleted=True)
mock_accessor.remove_file_from_graph.return_value = GraphRemoveFileFromGraphResponse(id="test_file_id", deleted=True)

mock_acquire_client.return_value = mock_accessor
yield mock_accessor


@pytest.fixture
def mock_files_accessor(sdk_file_mock):
with patch('writer.ai.File._retrieve_files_accessor') as mock_acquire_client:
mock_accessor = MagicMock()
mock_file = File(sdk_file_mock)
mock_accessor.retrieve.return_value = mock_file
mock_accessor.list.return_value = [mock_file]
mock_accessor.upload.return_value = mock_file
mock_accessor.delete.return_value = FileDeleteResponse(id="test_delete", deleted=True)
mock_acquire_client.return_value = mock_accessor
yield mock_accessor


@pytest.fixture
def created_graphs():
graphs = []
yield graphs
for graph in graphs:
delete_graph(graph_id_or_graph=graph.id)


@pytest.fixture
def created_files():
files = []
yield files
for file in files:
delete_file(file_id_or_file=file.id)


class FakeAppProcessForAIManager:
def __init__(self, token):
self.ai_manager = WriterAIManager(token=token)
Expand Down Expand Up @@ -295,6 +384,92 @@ def test_init_writer_ai_manager(emulate_app_process):
assert manager.client.api_key == "fake_token"


def test_create_graph(mock_graphs_accessor):
graph = create_graph(name="test_graph", description="A test graph")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert graph.id == "test_graph_id"
assert graph.name == "test_graph"


def test_retrieve_graph(mock_graphs_accessor):
graph = retrieve_graph(graph_id="test_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert graph.id == "test_graph_id"
assert graph.name == "test_graph"


def test_list_graphs(mock_graphs_accessor):
graphs = list_graphs()

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert len(graphs) == 1
assert graphs[0].id == "test_graph_id"
assert graphs[0].name == "test_graph"


def test_delete_graph(mock_graphs_accessor):
response = delete_graph(graph_id_or_graph="test_graph_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert response.deleted is True


def test_add_file_to_graph(mock_graphs_accessor, mock_files_accessor):
file = retrieve_file(file_id="test_file_id")
graph = retrieve_graph(graph_id="test_graph_id")
added_file = graph.add_file(file_id_or_file=file)

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert added_file.id == "test_file_id"
assert added_file.name == "test_file"
# Graph update should also trigger addition of its ID to stale_ids set
assert "test_graph_id" in Graph.stale_ids


def test_retrieve_file(mock_files_accessor):
file = retrieve_file(file_id="test_file_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert file.id == "test_file_id"
assert file.name == "test_file"


def test_list_files(mock_files_accessor):
files = list_files()

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert len(files) == 1
assert files[0].id == "test_file_id"
assert files[0].name == "test_file"


def test_upload_file(mock_files_accessor):
data = b"file_content"
file = upload_file(data=data, type="text/plain", name="uploaded_file")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert file.id == "test_file_id"
assert file.name == "test_file"


def test_delete_file(mock_files_accessor):
response = delete_file(file_id_or_file="test_file_id")

# As we modified SDK response, we expect this function
# to retrieve mock we prepared earlier
assert response.deleted is True


@explicit
def test_explicit_conversation_complete(emulate_app_process):
conversation = Conversation()
Expand Down Expand Up @@ -341,3 +516,90 @@ async def test_explicit_stream_complete(emulate_app_process):
assert isinstance(full_response, str)
assert len(full_response) > 0
assert "exercise" in full_response.lower()


@explicit
def test_explicit_create_graph(emulate_app_process, created_graphs):
graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(graph)
assert graph.id is not None
assert graph.name == "integration_test_graph"


@explicit
def test_explicit_retrieve_graph(emulate_app_process, created_graphs):
created_graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(created_graph)
graph = retrieve_graph(graph_id=created_graph.id)

assert graph.id == created_graph.id
assert graph.name == created_graph.name


@explicit
def test_explicit_list_graphs(emulate_app_process, created_graphs):
# Create a graph to ensure there's at least one graph in the list
graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(graph)

graphs = list_graphs()

assert len(graphs) > 0
assert any(g.id == graph.id for g in graphs)


@explicit
def test_explicit_delete_graph(emulate_app_process, created_graphs):
created_graph = create_graph(name="integration_test_graph", description="Integration test graph")
created_graphs.append(created_graph)
response = delete_graph(graph_id_or_graph=created_graph.id)

# Ensure the graph is removed from created_graphs
created_graphs.remove(created_graph)

assert response.deleted is True


@explicit
def test_explicit_upload_file(emulate_app_process, created_files):
data = b"file_content"
file = upload_file(data=data, type="text/plain", name="integration_uploaded_file")
created_files.append(file)

assert file.id is not None
assert file.name == "integration_uploaded_file.txt"


@explicit
def test_explicit_retrieve_file(emulate_app_process, created_files):
uploaded_file = upload_file(data=b"file_content", type="text/plain", name="integration_uploaded_file")
created_files.append(uploaded_file)
file = retrieve_file(file_id=uploaded_file.id)

assert file.id == uploaded_file.id
assert file.name == uploaded_file.name


@explicit
def test_explicit_list_files(emulate_app_process, created_files):
# Upload a file to ensure there's at least one file in the list
data = b"file_content"
file = upload_file(data=data, type="text/plain", name="integration_uploaded_file")
created_files.append(file)

files = list_files()

assert len(files) > 0
assert any(f.id == file.id for f in files)


@explicit
def test_explicit_delete_file(emulate_app_process, created_files):
uploaded_file = upload_file(data=b"file_content", type="text/plain", name="integration_uploaded_file")
created_files.append(uploaded_file)
response = delete_file(file_id_or_file=uploaded_file.id)

# Ensure the file is removed from created_files
created_files.remove(uploaded_file)

assert response.deleted is True

0 comments on commit 86f28ed

Please sign in to comment.