Merge pull request #95 from stanfordnlp/zen/llamachatdemo
[Minor] Support ITI Paper Results (#68)
frankaging authored Jan 27, 2024
2 parents 91de11f + 3dc3b59 commit d5beda8
Showing 8 changed files with 321 additions and 46 deletions.
File renamed without changes.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ logs
_hidden_local_dev*
tmp_dir_new*/
tmp/
tmp_*/
notebooks/figures/
*.tsv
*.csv
33 changes: 28 additions & 5 deletions README.md
@@ -30,6 +30,7 @@ import pyvene as pv
_, tokenizer, gpt2 = pv.create_gpt2()

pv_gpt2 = pv.IntervenableModel({
"layer": 0, "component": "block_output",
"source_representation": torch.zeros(gpt2.config.n_embd)
}, model=gpt2)

@@ -47,15 +48,37 @@ tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0483, -0.1212, -0.2816,  ...,  0.1958,  0.0830,  0.0784],
         [ 0.0519,  0.2547, -0.1631,  ...,  0.0050, -0.0453, -0.1624]]])
```
You can share your interventions through Hugging Face with others with a single call,

```python
pv_gpt2.save(
    save_directory="./your_gpt2_mounting_point/",
    save_to_hf_hub=True,
    hf_repo_name="your_gpt2_mounting_point",
)
```

## _IntervenableModel_ Loaded from HuggingFace Directly
The following code block reproduces [honest_llama-2 chat](https://github.com/likenneth/honest_llama/tree/master) from the paper [Inference-Time Intervention: Eliciting Truthful Answers from a Language Model](https://arxiv.org/abs/2306.03341). The added activations take up only **~0.14MB** on disk!

```python
# others can download from Hugging Face and use it directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pyvene as pv

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")

pv_model = pv.IntervenableModel.load(
    "zhengxuanzenwu/intervenable_honest_llama2_chat_7B",  # the activation diff ~0.14MB
    model,
)

print("llama-2-chat loaded with interventions:")
q = "What's a cure for insomnia that always works?"
prompt = tokenizer(q, return_tensors="pt").to("cuda")
_, iti_response_shared = pv_model.generate(prompt, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(iti_response_shared[0], skip_special_tokens=True))
```
With this, once you discover a clever intervention scheme, you can share it with others quickly, without sharing the base LM weights or your intervention code!


## _IntervenableModel_ as Regular _nn.Module_
You can also use `pv_gpt2` just like a regular torch module inside another model or pipeline,
```py
import torch
2 changes: 2 additions & 0 deletions pyvene/models/configuration_intervenable_model.py
@@ -31,6 +31,7 @@ def __init__(
        model_type=None,  # deprecating
        # hidden fields for backlog
        intervention_dimensions=None,
        intervention_constant_sources=None,
        **kwargs,
    ):
        if not isinstance(representations, list):
@@ -77,6 +78,7 @@ def __init__(
        self.mode = mode
        self.sorted_keys = sorted_keys
        self.intervention_dimensions = intervention_dimensions
        self.intervention_constant_sources = intervention_constant_sources
        self.model_type = model_type
        super().__init__(**kwargs)
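These two backlog fields are populated by `IntervenableModel.save()` and read back by `load()`. A minimal sketch of a config that will carry them after a save/load cycle (the representation spec and example values are hypothetical):

```python
import pyvene as pv

config = pv.IntervenableConfig([{
    "layer": 0, "component": "block_output",
}])
# filled in by save(), one entry per intervention:
#   config.intervention_dimensions        e.g. [768] or [None]
#   config.intervention_constant_sources  e.g. [True]
```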

83 changes: 45 additions & 38 deletions pyvene/models/intervenable_base.py
@@ -104,10 +104,13 @@ def __init__(self, config, model, **kwargs):
                else intervention_type[i]
            )
            all_metadata = representation._asdict()
            all_metadata["embed_dim"] = get_dimension_by_component(
            component_dim = get_dimension_by_component(
                get_internal_model_type(model), model.config,
                representation.component
            ) * int(representation.max_number_of_units)
            )
            if component_dim is not None:
                component_dim *= int(representation.max_number_of_units)
            all_metadata["embed_dim"] = component_dim
            all_metadata["use_fast"] = self.use_fast
            intervention = intervention_function(
                **all_metadata
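The `None` guard above matters for components addressed by arbitrary module-path strings, whose width may not be resolvable from `model.config`. A hypothetical spec that would exercise it (the path and intervention type are illustrative, not from this diff):

```python
import pyvene as pv

# for a dotted string-access component, get_dimension_by_component can
# return None; embed_dim is then left as None instead of raising on
# None * int(...)
pv_gpt2 = pv.IntervenableModel({
    "component": "transformer.h[0].mlp.output",  # hypothetical module path
    "intervention_type": pv.VanillaIntervention,
}, model=gpt2)
```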
@@ -234,10 +237,14 @@ def _get_representation_key(self, representation):
        Provide unique key for each intervention
        """
        l = representation.layer
        r = representation.component
        c = representation.component
        u = representation.unit
        n = representation.max_number_of_units
        key_proposal = f"layer.{l}.repr.{r}.unit.{u}.nunit.{n}"
        if "." in c:
            # string access for sure
            key_proposal = f"comp.{c}.unit.{u}.nunit.{n}"
        else:
            key_proposal = f"layer.{l}.comp.{c}.unit.{u}.nunit.{n}"
        if key_proposal not in self._key_collision_counter:
            self._key_collision_counter[key_proposal] = 0
        else:
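For reference, the two key shapes the new scheme produces (example components are hypothetical):

```python
# dotted string-access component: the layer prefix is dropped
#   "comp.transformer.h[0].mlp.output.unit.pos.nunit.1"
# plain named component: the layer prefix is kept
#   "layer.0.comp.block_output.unit.pos.nunit.1"
```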
@@ -337,8 +344,7 @@ def set_device(self, device):
        Set device of interventions and the model
        """
        for k, v in self.interventions.items():
            if isinstance(v[0], TrainableIntervention):
                v[0].to(device)
            v[0].to(device)
        self.model.to(device)

    def get_device(self):
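The simplified `set_device` now moves every intervention, not only trainable ones; constant-source interventions carry a `source_representation` buffer that could previously be left behind on CPU. A one-line usage sketch (assuming `pv_model` is an `IntervenableModel` and a CUDA device is available):

```python
# moves the base model and all intervention buffers, trainable or constant
pv_model.set_device("cuda")
```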
@@ -393,6 +399,7 @@ def save(
        )
        saving_config.intervention_types = []
        saving_config.intervention_dimensions = []
        saving_config.intervention_constant_sources = []

        # handle constant source reprs if passed in.
        serialized_representations = []
@@ -408,6 +415,8 @@
                    serialized_reprs[k] = None
                elif k == "intervention_type":
                    serialized_reprs[k] = None
                elif k == "intervention":
                    serialized_reprs[k] = None
                else:
                    serialized_reprs[k] = v
            serialized_representations += [
@@ -423,7 +432,7 @@
            # save intervention binary file
            if isinstance(intervention, TrainableIntervention) or \
                intervention.source_representation is not None:
                logging.warn(f"Saving trainable intervention to {binary_filename}.")
                # logging.info(f"Saving trainable intervention to {binary_filename}.")
                torch.save(
                    intervention.state_dict(),
                    os.path.join(save_directory, binary_filename),
@@ -433,8 +442,8 @@
                try:
                    api.create_repo(hf_repo_name)
                except:
                    logging.warn(
                        f"Skipping creating the repo since "
                    logging.info(
                        f"Uploading: {binary_filename}, but skipping creating the repo since "
                        f"either {hf_repo_name} exists or having authentication error."
                    )
                api.upload_file(
@@ -443,17 +452,21 @@
                    repo_id=hf_repo_name,
                    repo_type="model",
                )
            saving_config.intervention_dimensions += [intervention.interchange_dim.tolist()]

            if intervention.interchange_dim is None:
                saving_config.intervention_dimensions += [None]
            else:
                saving_config.intervention_dimensions += [intervention.interchange_dim.tolist()]
            saving_config.intervention_constant_sources += [intervention.is_source_constant]

        # save metadata config
        saving_config.save_pretrained(save_directory)
        if save_to_hf_hub:
            # push to huggingface hub
            try:
                api.create_repo(hf_repo_name)
            except:
                logging.warn(
                    f"Skipping creating the repo since "
                logging.info(
                    f"Uploading the config, Skipping creating the repo since "
                    f"either {hf_repo_name} exists or having authentication error."
                )
            api.upload_file(
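For orientation, the artifact that `save()` produces looks roughly like this (a sketch; the exact intervention key comes from `_get_representation_key`):

```python
# save_directory/
#   config.json       serialized IntervenableConfig, now also recording
#                     intervention_dimensions and intervention_constant_sources
#   intkey_<key>.bin  torch state dict for each trainable or
#                     constant-source intervention
```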
@@ -469,19 +482,13 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False
        Load interventions from disk or hub
        """
        if not os.path.exists(load_directory) or from_huggingface_hub:
            if local_directory is None:
                raise ValueError(
                    "You have to provide local_directory to save hf files."
                )
            from huggingface_hub import hf_hub_download

            hf_hub_download(
            from_huggingface_hub = True

            from huggingface_hub import snapshot_download
            load_directory = snapshot_download(
                repo_id=load_directory,
                filename="config.json",
                cache_dir=local_directory,
                local_dir=local_directory,
            )
            # simple overwrite
            load_directory = local_directory

        # load config
        saving_config = IntervenableConfig.from_pretrained(load_directory)
@@ -506,23 +513,23 @@
        for i, (k, v) in enumerate(intervenable.interventions.items()):
            intervention = v[0]
            binary_filename = f"intkey_{k}.bin"
            if isinstance(intervention, TrainableIntervention) or \
                (intervention.is_source_constant and \
                not isinstance(intervention, SourcelessIntervention)):
                if not os.path.exists(load_directory) or from_huggingface_hub:
                    hf_hub_download(
                        repo_id=load_directory,
                        filename=binary_filename,
                        cache_dir=local_directory,
                    )
                logging.warn(f"Loading trainable intervention from {binary_filename}.")
                if intervention.is_source_constant and not isinstance(intervention, ZeroIntervention):
                    saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
            intervention.is_source_constant = \
                saving_config.intervention_constant_sources[i]
            intervention.set_interchange_dim(saving_config.intervention_dimensions[i])
            if saving_config.intervention_constant_sources[i] and \
                not isinstance(intervention, ZeroIntervention) and \
                not isinstance(intervention, SourcelessIntervention):
                # logging.warn(f"Loading trainable intervention from {binary_filename}.")
                saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
                try:
                    intervention.register_buffer(
                        'source_representation', saved_state_dict['source_representation']
                    )
                    intervention.load_state_dict(saved_state_dict)
                    intervention.set_interchange_dim(saving_config.intervention_dimensions[i])
                except:
                    intervention.source_representation = saved_state_dict['source_representation']
            elif isinstance(intervention, TrainableIntervention):
                saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
                intervention.load_state_dict(saved_state_dict)

        return intervenable
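Taken together, `save()` and `load()` now round-trip constant-source interventions, and loading from the Hub no longer requires a `local_directory` because the whole repo is fetched via `snapshot_download`. A minimal round-trip sketch (paths are hypothetical; `model` is a loaded base LM):

```python
import pyvene as pv

pv_model.save(save_directory="./my_intervention/")  # or save_to_hf_hub=True

# reload onto a fresh copy of the base model; interchange dims and
# constant-source flags are restored from the serialized config
pv_model_2 = pv.IntervenableModel.load("./my_intervention/", model)
```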

3 changes: 3 additions & 0 deletions pyvene/models/interventions.py
@@ -48,6 +48,9 @@ def __init__(self, **kwargs):
        else:
            self.source_representation = None

    def set_source_representation(self, source_representation):
        self.is_source_constant = True
        self.register_buffer('source_representation', source_representation)

    def set_interchange_dim(self, interchange_dim):
        if isinstance(interchange_dim, int):
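The new `set_source_representation` helper registers the constant source as a buffer, so it follows `.to(device)` and is captured by `state_dict()` on save. A minimal usage sketch (the 768 dimension is illustrative):

```python
import torch

intervention.set_source_representation(torch.zeros(768))
assert intervention.is_source_constant  # the setter flips this flag
```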