diff --git a/tests/backend/test_ai.py b/tests/backend/test_ai.py index 0828062e2..b569f1a22 100644 --- a/tests/backend/test_ai.py +++ b/tests/backend/test_ai.py @@ -40,11 +40,33 @@ ``` """ +from datetime import datetime from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest -from writer.ai import Conversation, WriterAIManager, complete, init, stream_complete +from writer.ai import ( + Conversation, + File, + FileDeleteResponse, + Graph, + GraphDeleteResponse, + GraphRemoveFileFromGraphResponse, + SDKFile, + SDKGraph, + WriterAIManager, + complete, + create_graph, + delete_file, + delete_graph, + init, + list_files, + list_graphs, + retrieve_file, + retrieve_graph, + stream_complete, + upload_file, +) from writerai import Writer from writerai._streaming import Stream from writerai.types import Chat, ChatStreamingData, Completion, StreamingData @@ -116,6 +138,73 @@ def fake_response_content(): yield streaming_client +@pytest.fixture +def sdk_graph_mock(): + return SDKGraph( + id="test_graph_id", + created_at=datetime.now(), + file_status={"completed": 0, "failed": 0, "in_progress": 0, "total": 0}, + name="test_graph", + description="A test graph" + ) + + +@pytest.fixture +def sdk_file_mock(): + return SDKFile( + id="test_file_id", + created_at=datetime.now(), + graph_ids=["test_graph_id"], + name="test_file" + ) + + +@pytest.fixture +def mock_graphs_accessor(sdk_file_mock, sdk_graph_mock): + with patch('writer.ai.Graph._retrieve_graphs_accessor') as mock_acquire_client: + mock_accessor = MagicMock() + mock_graph = Graph(sdk_graph_mock) + mock_file = File(sdk_file_mock) + mock_accessor.create.return_value = mock_graph + mock_accessor.add_file_to_graph.return_value = mock_file + mock_accessor.retrieve.return_value = mock_graph + mock_accessor.list.return_value = [mock_graph] + mock_accessor.delete.return_value = GraphDeleteResponse(id="test_file_id", deleted=True) + mock_accessor.remove_file_from_graph.return_value = GraphRemoveFileFromGraphResponse(id="test_file_id", deleted=True) + + mock_acquire_client.return_value = mock_accessor + yield mock_accessor + + +@pytest.fixture +def mock_files_accessor(sdk_file_mock): + with patch('writer.ai.File._retrieve_files_accessor') as mock_acquire_client: + mock_accessor = MagicMock() + mock_file = File(sdk_file_mock) + mock_accessor.retrieve.return_value = mock_file + mock_accessor.list.return_value = [mock_file] + mock_accessor.upload.return_value = mock_file + mock_accessor.delete.return_value = FileDeleteResponse(id="test_delete", deleted=True) + mock_acquire_client.return_value = mock_accessor + yield mock_accessor + + +@pytest.fixture +def created_graphs(): + graphs = [] + yield graphs + for graph in graphs: + delete_graph(graph_id_or_graph=graph.id) + + +@pytest.fixture +def created_files(): + files = [] + yield files + for file in files: + delete_file(file_id_or_file=file.id) + + class FakeAppProcessForAIManager: def __init__(self, token): self.ai_manager = WriterAIManager(token=token) @@ -295,6 +384,92 @@ def test_init_writer_ai_manager(emulate_app_process): assert manager.client.api_key == "fake_token" +def test_create_graph(mock_graphs_accessor): + graph = create_graph(name="test_graph", description="A test graph") + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert graph.id == "test_graph_id" + assert graph.name == "test_graph" + + +def test_retrieve_graph(mock_graphs_accessor): + graph = retrieve_graph(graph_id="test_id") + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert graph.id == "test_graph_id" + assert graph.name == "test_graph" + + +def test_list_graphs(mock_graphs_accessor): + graphs = list_graphs() + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert len(graphs) == 1 + assert graphs[0].id == "test_graph_id" + assert graphs[0].name == "test_graph" + + +def test_delete_graph(mock_graphs_accessor): + response = delete_graph(graph_id_or_graph="test_graph_id") + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert response.deleted is True + + +def test_add_file_to_graph(mock_graphs_accessor, mock_files_accessor): + file = retrieve_file(file_id="test_file_id") + graph = retrieve_graph(graph_id="test_graph_id") + added_file = graph.add_file(file_id_or_file=file) + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert added_file.id == "test_file_id" + assert added_file.name == "test_file" + # Graph update should also trigger addition of its ID to stale_ids set + assert "test_graph_id" in Graph.stale_ids + + +def test_retrieve_file(mock_files_accessor): + file = retrieve_file(file_id="test_file_id") + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert file.id == "test_file_id" + assert file.name == "test_file" + + +def test_list_files(mock_files_accessor): + files = list_files() + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert len(files) == 1 + assert files[0].id == "test_file_id" + assert files[0].name == "test_file" + + +def test_upload_file(mock_files_accessor): + data = b"file_content" + file = upload_file(data=data, type="text/plain", name="uploaded_file") + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert file.id == "test_file_id" + assert file.name == "test_file" + + +def test_delete_file(mock_files_accessor): + response = delete_file(file_id_or_file="test_file_id") + + # As we modified SDK response, we expect this function + # to retrieve mock we prepared earlier + assert response.deleted is True + + @explicit def test_explicit_conversation_complete(emulate_app_process): conversation = Conversation() @@ -341,3 +516,90 @@ async def test_explicit_stream_complete(emulate_app_process): assert isinstance(full_response, str) assert len(full_response) > 0 assert "exercise" in full_response.lower() + + +@explicit +def test_explicit_create_graph(emulate_app_process, created_graphs): + graph = create_graph(name="integration_test_graph", description="Integration test graph") + created_graphs.append(graph) + assert graph.id is not None + assert graph.name == "integration_test_graph" + + +@explicit +def test_explicit_retrieve_graph(emulate_app_process, created_graphs): + created_graph = create_graph(name="integration_test_graph", description="Integration test graph") + created_graphs.append(created_graph) + graph = retrieve_graph(graph_id=created_graph.id) + + assert graph.id == created_graph.id + assert graph.name == created_graph.name + + +@explicit +def test_explicit_list_graphs(emulate_app_process, created_graphs): + # Create a graph to ensure there's at least one graph in the list + graph = create_graph(name="integration_test_graph", description="Integration test graph") + created_graphs.append(graph) + + graphs = list_graphs() + + assert len(graphs) > 0 + assert any(g.id == graph.id for g in graphs) + + +@explicit +def test_explicit_delete_graph(emulate_app_process, created_graphs): + created_graph = create_graph(name="integration_test_graph", description="Integration test graph") + created_graphs.append(created_graph) + response = delete_graph(graph_id_or_graph=created_graph.id) + + # Ensure the graph is removed from created_graphs + created_graphs.remove(created_graph) + + assert response.deleted is True + + +@explicit +def test_explicit_upload_file(emulate_app_process, created_files): + data = b"file_content" + file = upload_file(data=data, type="text/plain", name="integration_uploaded_file") + created_files.append(file) + + assert file.id is not None + assert file.name == "integration_uploaded_file.txt" + + +@explicit +def test_explicit_retrieve_file(emulate_app_process, created_files): + uploaded_file = upload_file(data=b"file_content", type="text/plain", name="integration_uploaded_file") + created_files.append(uploaded_file) + file = retrieve_file(file_id=uploaded_file.id) + + assert file.id == uploaded_file.id + assert file.name == uploaded_file.name + + +@explicit +def test_explicit_list_files(emulate_app_process, created_files): + # Upload a file to ensure there's at least one file in the list + data = b"file_content" + file = upload_file(data=data, type="text/plain", name="integration_uploaded_file") + created_files.append(file) + + files = list_files() + + assert len(files) > 0 + assert any(f.id == file.id for f in files) + + +@explicit +def test_explicit_delete_file(emulate_app_process, created_files): + uploaded_file = upload_file(data=b"file_content", type="text/plain", name="integration_uploaded_file") + created_files.append(uploaded_file) + response = delete_file(file_id_or_file=uploaded_file.id) + + # Ensure the file is removed from created_files + created_files.remove(uploaded_file) + + assert response.deleted is True