Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add steering concepts #213

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aleph_alpha_client/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ class CompletionRequest:
return the optimized completion in the completion field of the CompletionResponse.
The raw completion, if returned, will contain the un-optimized completion.

steering_concepts (Optional[list[str]], default None)
Names of the steering vectors to apply on this task. This steers the output in the
direction given by positive examples, and away from negative examples if provided.

Examples:
>>> prompt = Prompt.from_text("Provide a short description of AI:")
>>> request = CompletionRequest(prompt=prompt, maximum_tokens=20)
Expand Down Expand Up @@ -215,6 +219,7 @@ class CompletionRequest:
control_log_additive: Optional[bool] = True
repetition_penalties_include_completion: bool = True
raw_completion: bool = False
steering_concepts: Optional[List[str]] = None

def to_json(self) -> Mapping[str, Any]:
payload = {k: v for k, v in self._asdict().items() if v is not None}
Expand Down
28 changes: 28 additions & 0 deletions tests/test_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,31 @@ def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str)
number_tokens_completion = len(completion_result.completion_tokens)

assert response.num_tokens_generated == best_of * number_tokens_completion


@pytest.mark.system_test
def test_steering_completion(sync_client: Client, chat_model_name: str):
    """Steered and unsteered completions for the same prompt should differ.

    Sends one request without steering and one with the "shakespeare"
    steering concept, then checks both produce non-empty completions
    that are not identical to each other.
    """
    # Identical prompt for both requests so the only difference is steering.
    shared_prompt = Prompt.from_text(
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nParaphrase the following phrase. You are an honest man.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    plain_request = CompletionRequest(prompt=shared_prompt, maximum_tokens=16)
    shakespeare_request = CompletionRequest(
        prompt=shared_prompt,
        steering_concepts=["shakespeare"],
        maximum_tokens=16,
    )

    plain_text = (
        sync_client.complete(plain_request, model=chat_model_name)
        .completions[0]
        .completion
    )
    shakespeare_text = (
        sync_client.complete(shakespeare_request, model=chat_model_name)
        .completions[0]
        .completion
    )

    assert plain_text
    assert shakespeare_text
    # Steering should visibly change the model output.
    assert plain_text != shakespeare_text
Loading