Skip to content

Commit

Permalink
More general functionality testing for steering
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloiyu committed Dec 19, 2024
1 parent 72ea0d1 commit 743e801
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions tests/test_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,27 @@ def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str)

@pytest.mark.system_test
def test_steering_completion(sync_client: Client, chat_model_name: str):
request = CompletionRequest(
base_request = CompletionRequest(
prompt=Prompt.from_text(
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
),
maximum_tokens=16,
)

steered_request = CompletionRequest(
prompt=Prompt.from_text(
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
),
steering_concepts=["shakespeare"],
maximum_tokens=16,
)

response = sync_client.complete(request, model=chat_model_name)
completion_result = response.completions[0]
assert completion_result.completion is not None
assert (
"art" in completion_result.completion
), "Steered completion should contain Shakespearean language like 'art' for this particular phrase."
base_response = sync_client.complete(base_request, model=chat_model_name)
steered_response = sync_client.complete(steered_request, model=chat_model_name)

base_completion_result = base_response.completions[0].completion
steered_completion_result = steered_response.completions[0].completion

assert base_completion_result
assert steered_completion_result
assert base_completion_result != steered_completion_result

0 comments on commit 743e801

Please sign in to comment.