-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add documentation about TextControl (#1106)
* feat: Documentation for TextControls TASK: PHS-729 Co-authored-by: Niklas Köhnecke <[email protected]>
- Loading branch information
1 parent
6e67df4
commit cb4803a
Showing
3 changed files
with
312 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
296 changes: 296 additions & 0 deletions
296
src/documentation/attention_manipulation_with_text_controls.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,296 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from aleph_alpha_client import TextControl\n", | ||
"from dotenv import load_dotenv\n", | ||
"\n", | ||
"from intelligence_layer.core import CompleteInput, Llama3InstructModel, NoOpTracer\n", | ||
"\n", | ||
"load_dotenv()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1", | ||
"metadata": {}, | ||
"source": [ | ||
"# Attention Manipulation with `TextControl`\n", | ||
"\n", | ||
"`TextControl`s enable us to increase or decrease the attention of our model on different parts of the prompt (attention manipulation, AtMan).\n", | ||
"This can be convenient for influencing the model's behavior and priorities or for understanding why a model generates a given completion.\n", | ||
"\n", | ||
"Note: This notebook is quite sensitive to small changes in the model's behavior. The output of the model's might change slightly. We will therefore give the expected output in the form of comments, so you can compare your actual output to it. The basic message principles of the notebook of course should still hold.\n", | ||
"\n", | ||
"First, we define the instruction and input of our model and run it without any AtMan." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"instruction_text = (\n", | ||
" \"Say 'Hello' in one of the following languages. Say nothing else:\\nLanguages: \"\n", | ||
")\n", | ||
"input_text = \"Japanese and German\"\n", | ||
"\n", | ||
"llama_3_model = (\n", | ||
" Llama3InstructModel()\n", | ||
") # `TextControl` is only supported for `InstructModel`\n", | ||
"prompt_with_controls = llama_3_model.to_instruct_prompt(\n", | ||
" instruction=instruction_text,\n", | ||
" input=input_text,\n", | ||
")\n", | ||
"\n", | ||
"complete_input = CompleteInput(prompt=prompt_with_controls)\n", | ||
"output = llama_3_model.complete(complete_input, NoOpTracer())\n", | ||
"\n", | ||
"print(output.completion)\n", | ||
"####### Expected Output #######\n", | ||
"# Konnichiwa\n", | ||
"# Hallo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3", | ||
"metadata": {}, | ||
"source": [ | ||
"As you can see, the model does not comply with the \"one\" part of our instruction and gives us both translations. Let's fix this behavior." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4", | ||
"metadata": {}, | ||
"source": [ | ||
"## Manipulating the Attention on the Instruction\n", | ||
"To make the model only give us one translation, we increase the focus of the model on the word \"one\" . To this end, we create the `TextControl` for the instruction as follows: " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"highlight_of_one = \"one\"\n", | ||
"instruct_controls = [\n", | ||
" TextControl(\n", | ||
" start=instruction_text.index(\n", | ||
" highlight_of_one\n", | ||
" ), # Be careful to get the correct index\n", | ||
" length=len(highlight_of_one),\n", | ||
" factor=1.1, # Increase focus on \"one\" by 10%\n", | ||
" )\n", | ||
"]\n", | ||
"\n", | ||
"prompt_with_controls = llama_3_model.to_instruct_prompt(\n", | ||
" instruction=instruction_text,\n", | ||
" input=input_text,\n", | ||
" instruction_controls=instruct_controls,\n", | ||
")\n", | ||
"\n", | ||
"complete_input = CompleteInput(prompt=prompt_with_controls)\n", | ||
"output = llama_3_model.complete(complete_input, NoOpTracer())\n", | ||
"\n", | ||
"print(output.completion)\n", | ||
"\n", | ||
"####### Expected Output #######\n", | ||
"# Konnichiwa\n", | ||
"# Hallo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6", | ||
"metadata": {}, | ||
"source": [ | ||
"So this did not work. This is because we only increased the weight of the focus with the `factor` '1.1'. A `factor` of '1' would have no effect at all, and as it seems, an increase by 10% doesn't do the trick. So lets' increase it ten-fold. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"instruct_controls += [\n", | ||
" TextControl(\n", | ||
" start=instruction_text.index(highlight_of_one),\n", | ||
" length=len(highlight_of_one),\n", | ||
" factor=10,\n", | ||
" )\n", | ||
"]\n", | ||
"\n", | ||
"prompt_with_controls = llama_3_model.to_instruct_prompt(\n", | ||
" instruction=instruction_text,\n", | ||
" input=input_text,\n", | ||
" instruction_controls=instruct_controls,\n", | ||
")\n", | ||
"\n", | ||
"complete_input = CompleteInput(prompt=prompt_with_controls)\n", | ||
"output = llama_3_model.complete(complete_input, NoOpTracer())\n", | ||
"\n", | ||
"print(output.completion)\n", | ||
"\n", | ||
"####### Expected Output #######\n", | ||
"# Konnichiwa" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8", | ||
"metadata": {}, | ||
"source": [ | ||
"Finally, the model listens to the restriction. But what if we *also* want the model to be a bit less concise? " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9", | ||
"metadata": {}, | ||
"source": [ | ||
"### Using Multiple `TextControls`\n", | ||
"We can apply multiple `TextControl`s for to different parts of our instruction. We can use this to only get one translation and a less concise answer: " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "10", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"highlight_of_conciseness = \". Say nothing else\"\n", | ||
"instruct_controls = [\n", | ||
" TextControl(\n", | ||
" start=instruction_text.index(highlight_of_one),\n", | ||
" length=len(highlight_of_one),\n", | ||
" factor=10,\n", | ||
" ),\n", | ||
" TextControl(\n", | ||
" start=instruction_text.index(highlight_of_conciseness),\n", | ||
" length=len(highlight_of_conciseness),\n", | ||
" factor=0.25,\n", | ||
" ),\n", | ||
"]\n", | ||
"\n", | ||
"prompt_with_controls = llama_3_model.to_instruct_prompt(\n", | ||
" instruction=instruction_text,\n", | ||
" input=input_text,\n", | ||
" instruction_controls=instruct_controls,\n", | ||
")\n", | ||
"\n", | ||
"complete_input = CompleteInput(prompt=prompt_with_controls)\n", | ||
"output = llama_3_model.complete(complete_input, NoOpTracer())\n", | ||
"\n", | ||
"print(output.completion)\n", | ||
"\n", | ||
"####### Expected Output #######\n", | ||
"# Konnichiwa (Japanese)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "11", | ||
"metadata": {}, | ||
"source": [ | ||
"Feel free to experiment with the `factor` parameters of the `TextControl`s to see how the output changes. You will notice that are some sweet spots that change the output for the better or worse." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "12", | ||
"metadata": {}, | ||
"source": [ | ||
"## Manipulating the Attention on the Input\n", | ||
"We can also manipulate the attention on different parts of the input instead of the instruction. The procedure is the same, but we use the parameter `input_controls` of `to_instruct_prompt()`:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "13", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"instruct_controls = [\n", | ||
" TextControl(\n", | ||
" start=instruction_text.index(highlight_of_one),\n", | ||
" length=len(highlight_of_one),\n", | ||
" factor=10,\n", | ||
" ),\n", | ||
" TextControl(\n", | ||
" start=instruction_text.index(highlight_of_conciseness),\n", | ||
" length=len(highlight_of_conciseness),\n", | ||
" factor=0.3, # Notice, how we need to tweak this up a bit to get only one answer\n", | ||
" ),\n", | ||
"]\n", | ||
"\n", | ||
"highlight_of_language = \"German\"\n", | ||
"input_controls = [\n", | ||
" TextControl(\n", | ||
" start=input_text.index(highlight_of_language),\n", | ||
" length=len(highlight_of_language),\n", | ||
" factor=10,\n", | ||
" )\n", | ||
"]\n", | ||
"\n", | ||
"prompt_with_controls = llama_3_model.to_instruct_prompt(\n", | ||
" instruction=instruction_text,\n", | ||
" input=input_text,\n", | ||
" instruction_controls=instruct_controls,\n", | ||
" input_controls=input_controls,\n", | ||
")\n", | ||
"\n", | ||
"complete_input = CompleteInput(prompt=prompt_with_controls)\n", | ||
"output = llama_3_model.complete(complete_input, NoOpTracer())\n", | ||
"\n", | ||
"print(output.completion)\n", | ||
"\n", | ||
"####### Expected Output #######\n", | ||
"# Hallo (German)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "14", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |