Commit
Merge branch 'streamsync-cloud:dev' into dev
anant-writer authored May 28, 2024
2 parents 86c9970 + 001c193 commit 3cb2fd1
Showing 4 changed files with 368 additions and 2 deletions.
3 changes: 3 additions & 0 deletions pytest.ini
@@ -0,0 +1,3 @@
[pytest]
markers =
explicit: mark a test to be run only when explicitly specified
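
Registering the marker in `pytest.ini` keeps pytest from warning about an unknown mark (and, under `--strict-markers`, from failing collection). A minimal sketch of applying it (the test name is hypothetical):

```python
import pytest


@pytest.mark.explicit
def test_live_api_roundtrip():
    # Placeholder body; the real SDK query tests live in tests/backend/test_ai.py.
    ...
```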
7 changes: 5 additions & 2 deletions src/writer/ai.py
@@ -96,11 +96,14 @@ def acquire_instance(cls) -> 'WriterAIManager':
         :returns: The current instance of the manager.
         """
+        instance: WriterAIManager
         current_process = get_app_process()
 
         # If instance was not created explicitly, we initialize a new one
-        instance: WriterAIManager = \
-            getattr(current_process, 'ai_manager', cls())
+        try:
+            instance = getattr(current_process, 'ai_manager')
+        except AttributeError:
+            instance = cls()
         return instance
 
     @classmethod
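
A note on the motivation (inferred from the diff, not stated in the commit): `getattr(obj, name, default)` evaluates its default argument eagerly, so the old code constructed a throwaway `WriterAIManager` via `cls()` on every call, even when `current_process.ai_manager` already existed. The `try`/`except AttributeError` form only constructs one on a miss. A self-contained sketch with stand-in classes:

```python
class Costly:
    """Stand-in for WriterAIManager; counts constructions."""
    instances = 0

    def __init__(self):
        Costly.instances += 1


class Holder:
    """Stand-in for the app process object."""


holder = Holder()
holder.manager = Costly()                    # instances == 1

# Old pattern: the default expression runs even though the attribute exists.
m = getattr(holder, "manager", Costly())     # instances == 2

# New pattern: construction happens only when the attribute is missing.
try:
    m = getattr(holder, "manager")
except AttributeError:
    m = Costly()                             # not reached here; instances stays 2

assert Costly.instances == 2
```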
343 changes: 343 additions & 0 deletions tests/backend/test_ai.py
@@ -0,0 +1,343 @@
"""
# AI Module Test Suite
This module provides a suite of tests for the AI integration with the Writer SDK.
## Types of tests
1. **Mock tests**
- These tests simulate interactions with the Writer AI SDK without making real API calls.
- They are faster and can be run frequently during development.
2. **SDK query tests**
- These tests make real API calls to the Writer AI service.
- They are intended for potentially breaking changes and major releases, to verify compatibility with the live API.
## Running the tests
By default, SDK query tests carry the custom `explicit` marker and are excluded from regular test runs.
Only mock tests are run by a regular `pytest` command:
```sh
pytest ./tests/backend/test_ai.py
```
To run SDK query tests, ensure you have `pytest-env` installed:
```sh
pip install pytest-env
```
Then, set the `WRITER_API_KEY` environment variable in the `pytest.ini` file:
```ini
[pytest]
env =
WRITER_API_KEY=your_api_key_here
```
After that, to include SDK query tests in the run, use the `--full-run` option:
```sh
pytest ./tests/backend/test_ai.py --full-run
```
"""

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from writer.ai import Conversation, WriterAIManager, complete, init, stream_complete
from writerai import Writer
from writerai._streaming import Stream
from writerai.types import Chat, ChatStreamingData, Completion, StreamingData

# Decorator to mark tests as explicit, i.e. tests that are only run on direct demand
explicit = pytest.mark.explicit


test_complete_literal = "Completed text"


@pytest.fixture
def mock_non_streaming_client():
with patch('writer.ai.WriterAIManager.acquire_client') as mock_acquire_client:
original_client = Writer(api_key="fake_token")
non_streaming_client = AsyncMock(original_client)
mock_acquire_client.return_value = non_streaming_client

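        # Canned chat response returned for any chat call in the mock tests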
non_streaming_client.chat.chat.return_value = \
Chat(
id="test",
choices=[
{
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": "Response"
}
}
],
created=0,
model="test"
)
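        # Canned completion returned for any non-streaming completion call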
non_streaming_client.completions.create.return_value = \
Completion(choices=[{"text": test_complete_literal}])

yield non_streaming_client


@pytest.fixture
def mock_streaming_client():
def fake_response_content():
yield b'data: {"id":"test","choices":[{"finish_reason":"stop","message":{"content":"part1","role":"assistant"}}],"created":0,"model":"test"}\n\n'
yield b'data: {"id":"test","choices":[{"finish_reason":"stop","message":{"content":"part2","role":"assistant"}}],"created":0,"model":"test"}\n\n'
yield b'\n'
with patch('writer.ai.WriterAIManager.acquire_client') as mock_acquire_client:
original_client = Writer(api_key="fake_token")
streaming_client = AsyncMock(original_client)
mock_acquire_client.return_value = streaming_client

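        # Wrap the fake SSE bytes in a real Stream so chat streaming parses like a live response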
mock_chat_stream = Stream(
client=original_client,
cast_to=ChatStreamingData,
response=httpx.Response(
status_code=200,
content=fake_response_content()
)
)
streaming_client.chat.chat.return_value = mock_chat_stream

# Mock completion streaming
mock_completion_stream = MagicMock()
mock_completion_stream.__iter__.return_value = iter([
StreamingData(value="part1"),
StreamingData(value=" part2")
])
streaming_client.completions.create.return_value = mock_completion_stream

yield streaming_client


class FakeAppProcessForAIManager:
def __init__(self, token):
self.ai_manager = WriterAIManager(token=token)


def create_fake_app_process(token: str) -> FakeAppProcessForAIManager:
"""
Helper function to create and patch FakeAppProcessForAIManager with a given token.
"""
fake_process = FakeAppProcessForAIManager(token)
method_to_patch = 'writer.ai.WriterAIManager.acquire_instance'
patcher = patch(method_to_patch, return_value=fake_process.ai_manager)
patcher.start()
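    # The patcher is not stopped here; the emulate_app_process fixture calls patch.stopall() on teardown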
return fake_process


@pytest.fixture
def emulate_app_process(request):
token = None
marker = request.node.get_closest_marker('set_token')
if marker:
token = marker.args[0]
with patch('writer.ai.get_app_process') as mock_get_app_process:
fake_process = create_fake_app_process(token)
mock_get_app_process.return_value = fake_process
yield fake_process
patch.stopall()


def test_conversation_init_with_prompt():
# Initialize with a system prompt
prompt = "You are a social media expert in the financial industry"
conversation = Conversation(prompt)

assert len(conversation.messages) == 1
assert conversation.messages[0] == {
"role": "system",
"content": prompt,
"actions": None
}


def test_conversation_init_with_history():
# Initialize with a history of messages
history = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help?"}
]
conversation = Conversation(history)

assert len(conversation.messages) == len(history)
for i, message in enumerate(history):
assert conversation.messages[i] == {**message, "actions": None}


def test_conversation_add_message():
# Initialize a conversation and add a message
conversation = Conversation("Initial prompt")

conversation.add("user", "Hello")

assert len(conversation.messages) == 2
assert conversation.messages[1] == {
"role": "user",
"content": "Hello",
"actions": None
}


def test_conversation_add_chunk_to_last_message():
# Initialize a conversation and add chunks to the last message
conversation = Conversation("Initial prompt")

# Add initial message from the user
conversation.add("user", "Hello")

# Add a chunk
chunk1 = {"content": "How", "chunk": True}
conversation += chunk1

# Add another chunk
chunk2 = {"content": " are you?", "chunk": True}
conversation += chunk2

assert len(conversation.messages) == 2
assert conversation.messages[1] == {
"role": "user",
"content": "HelloHow are you?",
"actions": None
}


def test_conversation_validate_message():
# Test message validation
valid_message = {"role": "user", "content": "Hello"}
invalid_message_no_role = {"content": "Hello"}
invalid_message_no_content = {"role": "user"}
invalid_message_wrong_role = {"role": "unknown", "content": "Hello"}
invalid_message_non_dict = "Invalid message"

# This should pass without exceptions
Conversation.validate_message(valid_message)

# These should raise ValueError
with pytest.raises(ValueError):
Conversation.validate_message(invalid_message_no_role)
with pytest.raises(ValueError):
Conversation.validate_message(invalid_message_no_content)
with pytest.raises(ValueError):
Conversation.validate_message(invalid_message_wrong_role)
with pytest.raises(ValueError):
Conversation.validate_message(invalid_message_non_dict)


def test_conversation_serialized_messages_excludes_system():
# Initialize with a mix of system and non-system messages
history = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help?"},
{"role": "system", "content": "Another system message"}
]
conversation = Conversation(history)

serialized_messages = conversation.serialized_messages

# Ensure system messages are excluded
assert len(serialized_messages) == 2
assert all(message["role"] != "system" for message in serialized_messages)
assert serialized_messages[0] == \
{"role": "user", "content": "Hello", "actions": None}
assert serialized_messages[1] == \
{"role": "assistant", "content": "Hi, how can I help?", "actions": None}


@pytest.mark.set_token("fake_token")
def test_conversation_complete(emulate_app_process, mock_non_streaming_client):
conversation = Conversation()
response = conversation.complete()

assert response["role"] == "assistant"
assert response["content"] == "Response"


@pytest.mark.set_token("fake_token")
def test_conversation_stream_complete(emulate_app_process, mock_streaming_client):
# Create the Conversation object and collect the response chunks
conversation = Conversation("Initial prompt")

response_chunks = []
for chunk in conversation.stream_complete():
response_chunks.append(chunk)

# Verify the content
assert " ".join(chunk["content"] for chunk in response_chunks) == "part1 part2"


@pytest.mark.set_token("fake_token")
def test_complete(emulate_app_process, mock_non_streaming_client):
response = complete("test")

assert response == test_complete_literal


@pytest.mark.set_token("fake_token")
def test_stream_complete(emulate_app_process, mock_streaming_client):
response_chunks = list(stream_complete("test"))

assert "".join(response_chunks) == "part1 part2"


@pytest.mark.set_token("fake_token")
def test_init_writer_ai_manager(emulate_app_process):
manager = init("fake_token")
assert isinstance(manager, WriterAIManager)
assert manager.client.api_key == "fake_token"


@explicit
def test_explicit_conversation_complete(emulate_app_process):
conversation = Conversation()
conversation.add("user", "Hello, how can I improve my social media engagement?")

response = conversation.complete()

assert response["role"] == "assistant"
assert "engagement" in response["content"].lower()


@explicit
def test_explicit_conversation_stream_complete(emulate_app_process):
conversation = Conversation()
conversation.add("user", "Hello, how can I improve my social media engagement?")

response_chunks = []
for chunk in conversation.stream_complete():
response_chunks.append(chunk)

full_response = " ".join(chunk["content"] for chunk in response_chunks)
assert "engagement" in full_response.lower()


@explicit
@pytest.mark.asyncio
async def test_explicit_complete(emulate_app_process):
initial_text = "Write a short paragraph about the benefits of regular exercise."
response = complete(initial_text)

assert isinstance(response, str)
assert len(response) > 0
assert "exercise" in response.lower()


@explicit
@pytest.mark.asyncio
async def test_explicit_stream_complete(emulate_app_process):
initial_text = "Write a short paragraph about the benefits of regular exercise."

response_chunks = list(stream_complete(initial_text))

full_response = "".join(response_chunks)
assert isinstance(full_response, str)
assert len(full_response) > 0
assert "exercise" in full_response.lower()
17 changes: 17 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,17 @@
def pytest_collection_modifyitems(config, items):
if not config.getoption("--full-run"):
deselected = []
selected = []
for item in items:
if "explicit" in item.keywords:
deselected.append(item)
else:
selected.append(item)
items[:] = selected
config.hook.pytest_deselected(items=deselected)


def pytest_addoption(parser):
parser.addoption(
"--full-run", action="store_true", default=False, help="Include explicit-marked tests in the run (those are exluded from regular runs)"
)
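
A note on the design choice: deselection keeps excluded tests out of the report entirely, while the common alternative of skipping them is noisier but makes the exclusion visible in the summary. A minimal sketch of that variant (not the approach this commit takes):

```python
import pytest


def pytest_collection_modifyitems(config, items):
    # Variant: mark explicit tests as skipped instead of deselecting them.
    if config.getoption("--full-run"):
        return
    skip_explicit = pytest.mark.skip(reason="run with --full-run to include")
    for item in items:
        if "explicit" in item.keywords:
            item.add_marker(skip_explicit)
```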
