From 47932bb366b60adb021c8c9132213dcbe1c9699b Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Mon, 18 Nov 2024 13:21:55 +0300 Subject: [PATCH] chore: tests for AI module tool calls --- tests/backend/test_ai.py | 518 +++++++++++++++++++++++++++++++++++---- 1 file changed, 468 insertions(+), 50 deletions(-) diff --git a/tests/backend/test_ai.py b/tests/backend/test_ai.py index 23165a94d..819cbe942 100644 --- a/tests/backend/test_ai.py +++ b/tests/backend/test_ai.py @@ -59,6 +59,7 @@ apps, ask, complete, + create_function_tool, create_graph, delete_file, delete_graph, @@ -103,80 +104,166 @@ def mock_app_content_generation(): @pytest.fixture -def mock_non_streaming_client(): +def mock_writer_client(): + """Mock fixture for Writer client with configurable behavior.""" with patch('writer.ai.WriterAIManager.acquire_client') as mock_acquire_client: original_client = Writer(api_key="fake_token") - non_streaming_client = AsyncMock(original_client) - mock_acquire_client.return_value = non_streaming_client + mock_client = AsyncMock(original_client) + mock_acquire_client.return_value = mock_client + + # Attach the original client for use in tests + mock_client._original_client = original_client + + # Basic response mock + mock_client.completions.create.return_value = Completion( + choices=[{"text": "Completed text"}] + ) - non_streaming_client.chat.chat.return_value = \ - Chat( + def mock_chat_response(include_tool_calls=None): + """Configurable mock chat response.""" + tool_calls = include_tool_calls if include_tool_calls else None + return Chat( id="test", choices=[ { "finish_reason": "stop", "index": 0, - "logprobs": {}, "message": { "role": "assistant", "content": "Response", - "refusal": "false" - } - }, - + "tool_calls": tool_calls, + }, + } ], created=0, model="test", object="chat.completion" ) - non_streaming_client.completions.create.return_value = \ + mock_client.completions.create.return_value = \ Completion(choices=[{"text": test_complete_literal}]) - non_streaming_client.graphs.question.return_value = \ + mock_client.graphs.question.return_value = \ MagicMock(answer="Mocked Answer") - yield non_streaming_client + mock_client.chat.create_chat_response = mock_chat_response + + # Shared fake stream response generator + def fake_stream_response(): + yield b'data: {"choices":[{"message":{"content":"part1","role":"assistant"}}]}\n\n' + yield b'data: {"choices":[{"message":{"content":"part2","role":"assistant"}}]}\n\n' + + mock_client.fake_stream_response = fake_stream_response + yield mock_client @pytest.fixture -def mock_streaming_client(): - def fake_response_content(): - yield b'data: {"id":"test","choices":[{"finish_reason":"stop","message":{"content":"part1","role":"assistant"}}],"created":0,"model":"test"}\n\n' - yield b'data: {"id":"test","choices":[{"finish_reason":"stop","message":{"content":"part2","role":"assistant"}}],"created":0,"model":"test"}\n\n' - yield b'\n' - with patch('writer.ai.WriterAIManager.acquire_client') as mock_acquire_client: - original_client = Writer(api_key="fake_token") - streaming_client = AsyncMock(original_client) - mock_acquire_client.return_value = streaming_client - - mock_chat_stream = Stream( - client=original_client, - cast_to=ChatCompletionChunk, - response=httpx.Response( - status_code=200, - content=fake_response_content() - ) - ) - streaming_client.chat.chat.return_value = mock_chat_stream - - # Mock completion streaming - mock_completion_stream = MagicMock() - mock_completion_stream.__iter__.return_value = iter([ - StreamingData(value="part1"), - StreamingData(value=" part2") - ]) - streaming_client.completions.create.return_value = \ +def mock_non_streaming_client(mock_writer_client): + """Mock client without tool calls.""" + mock_writer_client.chat.chat.return_value = mock_writer_client.chat.create_chat_response() + yield mock_writer_client + + +@pytest.fixture +def mock_tool_calls_client(mock_writer_client): + """Mock client with tool calls returned only once.""" + tool_calls = [ + {"id": "1", "type": "function", "function": {"name": "test_function", "arguments": '{"arg1": 5}'}}, + {"id": "2", "type": "function", "function": {"name": "test_function", "arguments": '{"arg1": 7}'}}, + ] + + tool_calls_control = { + "check_enabled": True, # Toggle for enabling/disabling the check + "returned": False # Tracks if tool calls have been returned + } + def conditional_chat_response(*args, **kwargs): + # Return tool calls only on the first call + if tool_calls_control["returned"]: + return mock_writer_client.chat.create_chat_response() + else: + if tool_calls_control["check_enabled"]: + tool_calls_control["returned"] = True + return mock_writer_client.chat.create_chat_response( + include_tool_calls=mock_writer_client.tool_calls + ) + + # Apply the conditional response to the mock + mock_writer_client.tool_calls = tool_calls + mock_writer_client.tool_calls_control = tool_calls_control + mock_writer_client.chat.chat.side_effect = conditional_chat_response + yield mock_writer_client + + +@pytest.fixture +def mock_streaming_client(mock_writer_client): + """Mock client for streaming.""" + mock_stream = Stream( + client=mock_writer_client._original_client, + cast_to=Chat, + response=httpx.Response(200, content=mock_writer_client.fake_stream_response()), + ) + mock_writer_client.chat.chat.return_value = mock_stream + + mock_completion_stream = MagicMock() + mock_completion_stream.__iter__.return_value = iter([ + StreamingData(value="part1"), + StreamingData(value=" part2") + ]) + mock_writer_client.completions.create.return_value = \ mock_completion_stream - # Mock question streaming - mock_graph_stream = MagicMock() - mock_graph_stream._iter_events.return_value = iter([ - MagicMock(data='{"answer": "Part 1"}'), - MagicMock(data='{"answer": "Part 2"}'), - ]) - streaming_client.graphs.question.return_value = mock_graph_stream + # Mock question streaming + mock_graph_stream = MagicMock() + mock_graph_stream._iter_events.return_value = iter([ + MagicMock(data='{"answer": "Part 1"}'), + MagicMock(data='{"answer": "Part 2"}'), + ]) + mock_writer_client.graphs.question.return_value = mock_graph_stream + yield mock_writer_client - yield streaming_client + +@pytest.fixture +def mock_streaming_tool_calls_client(mock_writer_client): + """Mock client with tool calls returned only once.""" + tool_calls_control = { + "check_enabled": True, # Toggle for enabling/disabling the check + "returned": False # Tracks if tool calls have been returned + } + tool_call_stream = [ + b'data: {"id":"a2a302fa-a85c-44b4-9c20-0956d557517c","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":"","role":"assistant","tool_calls":null,"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"dc06043b-b002-40d3-b3fb-e6f1c2a090bc","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":0,"id":"chatcmpl-tool-1ff1df7d81074e5995ec77af2911f7c1","type":"function","function":{"name":"test_function","arguments":null}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"10d482f0-e370-41f6-9101-846d3ccbd3c6","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":0,"id":null,"type":null,"function":{"name":null,"arguments":"{\\"arg1\\": 5"}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"3eaf20c3-5af0-4b6e-a604-e6bd62bd24ee","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":0,"id":null,"type":null,"function":{"name":null,"arguments":""}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"433e565b-38e9-4d20-8c27-ffc33b252669","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":0,"id":null,"type":null,"function":{"name":null,"arguments":"}"}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"5091cdb6-2bed-4a55-9625-77c5cd607383","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":"\\n","role":"assistant","tool_calls":null,"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"b526eb17-0751-4d6b-8c61-332a28efbac3","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":1,"id":"chatcmpl-tool-22538c865437437e8ace8076a5755749","type":"function","function":{"name":"test_function","arguments":null}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"6cc43656-4ea1-43cc-8538-212476bc5681","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":1,"id":null,"type":null,"function":{"name":null,"arguments":"{\\"arg1\\": 7"}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"a68791cf-f01a-42fd-87bc-d209b2075e13","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":null,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":1,"id":null,"type":null,"function":{"name":null,"arguments":""}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + b'data: {"id":"168bc578-1916-448b-8373-1c473383c0cb","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":"tool_calls","delta":{"content":null,"role":"assistant","tool_calls":[{"index":1,"id":null,"type":null,"function":{"name":null,"arguments":"}"}}],"graph_data":{"sources":null,"status":null,"subqueries":null},"refusal":null},"logprobs":null}],"created":1731919162,"model":"palmyra-x-004","usage":null,"system_fingerprint":"v1","service_tier":null}\n\n', + ] + def fake_stream_with_tool_calls(): + """Generate streaming data with tool calls in JSON format for the first chunk.""" + if tool_calls_control["returned"]: + # Use the original fake stream response + yield from mock_writer_client.fake_stream_response() + else: + if tool_calls_control["check_enabled"]: + tool_calls_control["returned"] = True + for chunk in mock_writer_client.tool_call_stream: + yield chunk + + # Mock the streaming response with the modified generator + def create_mock_stream(*args, **kwargs): + mock_stream = Stream( + client=mock_writer_client._original_client, + cast_to=Chat, + response=httpx.Response(200, content=fake_stream_with_tool_calls()), + ) + return mock_stream + + mock_writer_client.tool_call_stream = tool_call_stream + mock_writer_client.tool_calls_control = tool_calls_control + mock_writer_client.chat.chat.side_effect = create_mock_stream + yield mock_writer_client @pytest.fixture @@ -405,6 +492,147 @@ def test_conversation_stream_complete(emulate_app_process, mock_streaming_client assert " ".join(chunk["content"] for chunk in response_chunks) == "part1 part2" +@pytest.mark.set_token("fake_token") +def test_conversation_with_tool_call(emulate_app_process, mock_tool_calls_client): + def test_function(arg1): + return int(arg1) ** 2 + + conversation = Conversation() + conversation._register_callable(test_function, "test_function", {"arg1": {"type": "integer"}}) + + conversation.add("user", "Call a tool") + _ = conversation.complete(tools=[{ + "type": "function", + "callable": test_function, + "name": "test_function", + "parameters": {"arg1": {"type": "string"}} + }]) + tool_results = [message for message in conversation.messages if message["role"] == "tool"] + assert len(tool_results) == 2 + assert tool_results[0]["content"] == "25" + assert tool_results[1]["content"] == "49" + + +@pytest.mark.set_token("fake_token") +def test_conversation_with_stream_tool_call(emulate_app_process, mock_streaming_tool_calls_client): + def test_function(arg1): + return int(arg1) ** 2 + + conversation = Conversation() + conversation._register_callable(test_function, "test_function", {"arg1": {"type": "integer"}}) + + conversation.add("user", "Call a tool") + response = conversation.stream_complete(tools=[{ + "type": "function", + "callable": test_function, + "name": "test_function", + "parameters": {"arg1": {"type": "string"}} + }]) + # Initiate streaming to trigger the calls + for _ in response: + pass + tool_results = [message for message in conversation.messages if message["role"] == "tool"] + assert len(tool_results) == 2 + assert tool_results[0]["content"] == "25" + assert tool_results[1]["content"] == "49" + + +@pytest.mark.set_token("fake_token") +def test_conversation_with_bad_tool_call(emulate_app_process, mock_tool_calls_client): + def test_function(arg1): + return int(arg1) ** 2 + + # Prepare bad tool call data + mock_tool_calls_client.tool_calls = [ + {"id": "1", "type": "function", "function": {"name": "test_function", "arguments": '{"arg1": "invalid"}'}}, + {"id": "2", "type": "function", "function": {"name": "test_function", "arguments": '{"arg2": 5}'}}, # Missing 'arg1' + ] + + conversation = Conversation() + conversation._register_callable(test_function, "test_function", {"arg1": {"type": "integer"}}) + + conversation.add("user", "Call a tool") + + _ = conversation.complete(tools=[{ + "type": "function", + "callable": test_function, + "name": "test_function", + "parameters": {"arg1": {"type": "string"}} + }]) + tool_results = [message for message in conversation.messages if message["role"] == "tool"] + assert len(tool_results) == 2 + assert "inform the user about the error" in tool_results[0]["content"] + assert "inform the user about the error" in tool_results[1]["content"] + + +@pytest.mark.set_token("fake_token") +def test_conversation_with_stream_bad_tool_call(emulate_app_process, mock_streaming_tool_calls_client): + def test_function(arg1): + return int(arg1) ** 2 + + # Prepare bad tool call data + mock_streaming_tool_calls_client.tool_call_stream = [ + b'data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":null,"role":"assistant","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"test_function","arguments":"{\\"arg1\\":\\"invalid\\"}"}}]}}]}\n\n', + b'data: {"id":"2","object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":"tool_calls","delta":{"content":null,"role":"assistant","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"test_function","arguments":"{\\"arg2\\": 5}"}}]}}]}\n\n', # Missing 'arg1' + ] + + conversation = Conversation() + conversation._register_callable(test_function, "test_function", {"arg1": {"type": "integer"}}) + + conversation.add("user", "Call a tool") + + response = conversation.stream_complete(tools=[{ + "type": "function", + "callable": test_function, + "name": "test_function", + "parameters": {"arg1": {"type": "string"}} + }]) + # Initiate streaming to trigger the calls + for _ in response: + pass + tool_results = [message for message in conversation.messages if message["role"] == "tool"] + assert len(tool_results) == 2 + assert "inform the user about the error" in tool_results[0]["content"] + assert "inform the user about the error" in tool_results[1]["content"] + +@pytest.mark.set_token("fake_token") +def test_conversation_with_tool_call_max_depth(emulate_app_process, mock_tool_calls_client): + def test_function(arg1): + return int(arg1) ** 2 + mock_tool_calls_client.tool_calls_control["check_enabled"] = False + conversation = Conversation() + conversation._register_callable(test_function, "test_function", {"arg1": {"type": "integer"}}) + + conversation.add("user", "Call a tool") + with pytest.raises(RuntimeError): + _ = conversation.complete(tools=[{ + "type": "function", + "callable": test_function, + "name": "test_function", + "parameters": {"arg1": {"type": "string"}} + }]) + + +@pytest.mark.set_token("fake_token") +def test_conversation_with_stream_tool_call_max_depth(emulate_app_process, mock_streaming_tool_calls_client): + def test_function(arg1): + return int(arg1) ** 2 + mock_streaming_tool_calls_client.tool_calls_control["check_enabled"] = False + conversation = Conversation() + conversation._register_callable(test_function, "test_function", {"arg1": {"type": "integer"}}) + + conversation.add("user", "Call a tool") + with pytest.raises(RuntimeError): + response = conversation.stream_complete(tools=[{ + "type": "function", + "callable": test_function, + "name": "test_function", + "parameters": {"arg1": {"type": "string"}} + }]) + # Initiate streaming to trigger the calls + for _ in response: + pass + @pytest.mark.set_token("fake_token") def test_complete(emulate_app_process, mock_non_streaming_client): response = complete("test") @@ -614,6 +842,196 @@ def test_explicit_conversation_stream_complete(emulate_app_process): assert "engagement" in full_response.lower() +@explicit +def test_explicit_conversation_complete_tool_calls(emulate_app_process): + secret_word_one = "PARIS" + correct_password = "Sunshine" + secret_word_two = "Moonglow" + number = 15 + coefficient = 2.7 + + def test_function_one(): + return secret_word_one + + def test_function_two(password): + if password == correct_password: + return secret_word_two + + def test_function_three(number: int, coefficient: float): + return number * coefficient + + conversation = Conversation() + conversation.add( + "user", + "Use the function to retrieve the secret word. " + + "Name only the secret word and nothing else." + ) + + # Test a function with no parameters + first_response = conversation.complete( + tools=create_function_tool( + callable=test_function_one, + name="get_secret_word", + parameters=None, + description="A function to retrieve the secret word." + ) + ) + + assert first_response["role"] == "assistant" + assert first_response["content"] == secret_word_one + + conversation = Conversation() + conversation.add( + "user", + "Use the function to retrieve the secret word. " + + "Name only the secret word and nothing else. " + + f"The password is {correct_password}" + ) + # Test a function with a single parameter + second_response = conversation.complete( + tools=create_function_tool( + callable=test_function_two, + name="get_secret_word_by_password", + parameters={ + "password": { + "type": "string", + "description": "A password used to retrieve the secret word" + } + } + ) + ) + + assert second_response["role"] == "assistant" + assert second_response["content"] == secret_word_two + + conversation = Conversation() + conversation.add( + "user", + "Use the function to calculate the final number. " + + "Respond with only the resulting calculation and nothing else. " + + f"The number is {number}. The coefficient is {coefficient}." + ) + + # Test a function with two non-string parameters + third_response = conversation.complete( + tools=create_function_tool( + callable=test_function_three, + name="calculate", + parameters={ + "number": { + "type": "integer", + "description": "The base number to perform calculation against" + }, + "coefficient": { + "type": "float", + "description": "The coefficient to use against the number" + } + } + ) + ) + + assert third_response["role"] == "assistant" + assert third_response["content"] == str(number * coefficient) + + +@explicit +def test_explicit_conversation_stream_complete_tool_calls(emulate_app_process): + secret_word_one = "PARIS" + correct_password = "Sunshine" + secret_word_two = "Moonglow" + number = 15 + coefficient = 2.7 + + def test_function_one(): + return secret_word_one + + def test_function_two(password): + if password == correct_password: + return secret_word_two + + def test_function_three(number: int, coefficient: float): + return number * coefficient + + conversation = Conversation() + conversation.add( + "user", + "Use the function to retrieve the secret word. " + + "Name only the secret word and nothing else." + ) + + # Test a function with no parameters + first_response_stream = conversation.stream_complete( + tools=create_function_tool( + callable=test_function_one, + name="get_secret_word", + parameters=None, + description="A function to retrieve the secret word." + ) + ) + first_response = "" + for chunk in first_response_stream: + first_response += chunk.get("content") + + assert first_response == secret_word_one + + conversation = Conversation() + conversation.add( + "user", + "Use the function to retrieve the secret word. " + + "Name only the secret word and nothing else. " + + f"The password is {correct_password}" + ) + # Test a function with a single parameter + second_response_stream = conversation.stream_complete( + tools=create_function_tool( + callable=test_function_two, + name="get_secret_word_by_password", + parameters={ + "password": { + "type": "string", + "description": "A password used to retrieve the secret word" + } + } + ) + ) + second_response = "" + for chunk in second_response_stream: + second_response += chunk.get("content") + + assert second_response == secret_word_two + + conversation = Conversation() + conversation.add( + "user", + "Use the function to calculate the final number. " + + "Respond with only the resulting calculation and nothing else. " + + f"The number is {number}. The coefficient is {coefficient}." + ) + + # Test a function with two non-string parameters + third_response_stream = conversation.stream_complete( + tools=create_function_tool( + callable=test_function_three, + name="calculate", + parameters={ + "number": { + "type": "integer", + "description": "The base number to perform calculation against" + }, + "coefficient": { + "type": "float", + "description": "The coefficient to use against the number" + } + } + ) + ) + third_response = "" + for chunk in third_response_stream: + third_response += chunk.get("content") + + assert third_response == str(number * coefficient) + + @explicit @pytest.mark.asyncio async def test_explicit_complete(emulate_app_process): @@ -764,7 +1182,7 @@ def test_explicit_ask_graph_class( ) assert isinstance(answer, str) - assert answer == " PARIS" + assert answer == "PARIS" @explicit @@ -852,7 +1270,7 @@ def test_explicit_ask( ) assert isinstance(answer, str) - assert answer == " PARIS" + assert answer == "PARIS" @explicit