def test_steering_completion(sync_client: Client, chat_model_name: str):
    """Compare a plain completion with one requested with a steering concept.

    Two requests are built from the same prompt text and token limit; the
    second one additionally sets ``steering_concepts=["shakespeare"]``. Both
    are sent through ``sync_client.complete`` against ``chat_model_name``.

    The test passes when both first completions are non-empty and their
    texts differ from each other.
    """
    # Single source for the prompt text so both requests are guaranteed to
    # differ only in the steering argument.
    prompt_text = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    token_limit = 16

    plain_request = CompletionRequest(
        prompt=Prompt.from_text(prompt_text),
        maximum_tokens=token_limit,
    )
    steered_request = CompletionRequest(
        prompt=Prompt.from_text(prompt_text),
        steering_concepts=["shakespeare"],
        maximum_tokens=token_limit,
    )

    # Unsteered request goes first, then the steered one.
    plain_response = sync_client.complete(plain_request, model=chat_model_name)
    steered_response = sync_client.complete(steered_request, model=chat_model_name)

    plain_text = plain_response.completions[0].completion
    steered_text = steered_response.completions[0].completion

    assert plain_text
    assert steered_text
    assert plain_text != steered_text