diff --git a/tests/test_complete.py b/tests/test_complete.py index 41db92a..885175e 100644 --- a/tests/test_complete.py +++ b/tests/test_complete.py @@ -192,7 +192,14 @@ def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str) @pytest.mark.system_test def test_steering_completion(sync_client: Client, chat_model_name: str): - request = CompletionRequest( + base_request = CompletionRequest( + prompt=Prompt.from_text( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ), + maximum_tokens=16, + ) + + steered_request = CompletionRequest( prompt=Prompt.from_text( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ), @@ -200,9 +207,12 @@ def test_steering_completion(sync_client: Client, chat_model_name: str): maximum_tokens=16, ) - response = sync_client.complete(request, model=chat_model_name) - completion_result = response.completions[0] - assert completion_result.completion is not None - assert ( - "art" in completion_result.completion - ), "Steered completion should contain Shakespearean language like 'art' for this particular phrase." + base_response = sync_client.complete(base_request, model=chat_model_name) + steered_response = sync_client.complete(steered_request, model=chat_model_name) + + base_completion_result = base_response.completions[0].completion + steered_completion_result = steered_response.completions[0].completion + + assert base_completion_result + assert steered_completion_result + assert base_completion_result != steered_completion_result