-
Notifications
You must be signed in to change notification settings - Fork 288
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added llama 3.1 models and base for working on mllama #740
Conversation
The checks that are failing is because the system does not have the latest version of transformers |
tested making this model [mylesgoose/Llama-3.1-70B-Instruct-abliterated](mylesgoose/Llama-3.1-70B-Instruct-abliterated](https://huggingface.co/mylesgoose/Llama-3.1-70B-Instruct-abliterated/tree/main) |
# # Uncensor any LLM with abliteration
#
# > 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)
#
# ❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).
#!pip3 install -qqq transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping
import torch
import functools
import einops
import gc
from datasets import load_dataset
from tqdm import tqdm
from torch import Tensor
from typing import List
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoModelForCausalLM, AutoTokenizer, MllamaForConditionalGeneration, AutoProcessor, LlamaForCausalLM
from jaxtyping import Float, Int
from collections import defaultdict
# Turn automatic differentiation off to save GPU memory (credit: Undi95)
torch.set_grad_enabled(False)
def reformat_texts(texts):
return [[{"role": "user", "content": text}] for text in texts]
# Get harmful and harmless datasets
def get_harmful_instructions():
dataset = load_dataset('mlabonne/harmful_behaviors')
return reformat_texts(dataset['train']['text']), reformat_texts(dataset['test']['text'])
def get_harmless_instructions():
dataset = load_dataset('mlabonne/harmless_alpaca')
return reformat_texts(dataset['train']['text']), reformat_texts(dataset['test']['text'])
harmful_inst_train, harmful_inst_test = get_harmful_instructions()
harmless_inst_train, harmless_inst_test = get_harmless_instructions()
MODEL_ID = "meta-llama/Llama-3.1-70B-Instruct"
NEW_MODEL_ID = "mylesgoose/Llama-3.1-70B-Instruct-abliterated"
MODEL_TYPE = "meta-llama/Llama-3.1-70B-Instruct"
MODEL_PATH = 'meta-llama/Llama-3.1-70B-Instruct'
# Download and load model
#!git clone https://huggingface.co/{MODEL_ID} {MODEL_TYPE}
# Load model and tokenizer
model = HookedTransformer.from_pretrained_no_processing(
MODEL_PATH,
#local_files_only=True,
dtype=torch.bfloat16,
n_devices=8,
default_padding_side='right'
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
def tokenize_instructions(tokenizer, instructions):
return tokenizer.apply_chat_template(
instructions,
padding=True,
truncation=False,
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
).input_ids
n_inst_train = min(256, len(harmful_inst_train), len(harmless_inst_train))
# Tokenize datasets
harmful_tokens = tokenize_instructions(
tokenizer,
instructions=harmful_inst_train[:n_inst_train],
)
harmless_tokens = tokenize_instructions(
tokenizer,
instructions=harmless_inst_train[:n_inst_train],
)
# Define batch size based on available VRAM
batch_size = 24
# Initialize defaultdicts to store activations
harmful = defaultdict(list)
harmless = defaultdict(list)
# Process the training data in batches
num_batches = (n_inst_train + batch_size - 1) // batch_size
for i in tqdm(range(num_batches)):
print(i)
start_idx = i * batch_size
end_idx = min(n_inst_train, start_idx + batch_size)
# Run models on harmful and harmless prompts, cache activations
harmful_logits, harmful_cache = model.run_with_cache(
harmful_tokens[start_idx:end_idx],
names_filter=lambda hook_name: 'resid' in hook_name,
device='cpu',
reset_hooks_end=True
)
harmless_logits, harmless_cache = model.run_with_cache(
harmless_tokens[start_idx:end_idx],
names_filter=lambda hook_name: 'resid' in hook_name,
device='cpu',
reset_hooks_end=True
)
# Collect and store the activations
for key in harmful_cache:
harmful[key].append(harmful_cache[key])
harmless[key].append(harmless_cache[key])
# Flush RAM and VRAM
del harmful_logits, harmless_logits, harmful_cache, harmless_cache
gc.collect()
torch.cuda.empty_cache()
# Concatenate the cached activations
harmful = {k: torch.cat(v) for k, v in harmful.items()}
harmless = {k: torch.cat(v) for k, v in harmless.items()}
# Helper function to get activation index
def get_act_idx(cache_dict, act_name, layer):
key = (act_name, layer)
return cache_dict[utils.get_act_name(*key)]
# Compute difference of means between harmful and harmless activations at intermediate layers
activation_layers = ["resid_pre", "resid_mid", "resid_post"]
activation_refusals = defaultdict(list)
for layer_num in range(1, model.cfg.n_layers):
pos = -1 # Position index
for layer in activation_layers:
harmful_mean_act = get_act_idx(harmful, layer, layer_num)[:, pos, :].mean(dim=0)
harmless_mean_act = get_act_idx(harmless, layer, layer_num)[:, pos, :].mean(
dim=0
)
refusal_dir = harmful_mean_act - harmless_mean_act
refusal_dir = refusal_dir / refusal_dir.norm()
activation_refusals[layer].append(refusal_dir)
# Get all calculated potential refusal directions, sort them in descending order based on their mean
# Use a subset of layers if certain activations are not promising
selected_layers = ["resid_pre"]
activation_scored = sorted(
[
activation_refusals[layer][l - 1]
for l in range(1, model.cfg.n_layers)
for layer in selected_layers
],
key=lambda x: abs(x.mean()),
reverse=True,
)
def _generate_with_hooks(
model: HookedTransformer,
tokenizer: AutoTokenizer,
tokens: Int[Tensor, "batch_size seq_len"],
max_tokens_generated: int = 64,
fwd_hooks=[],
) -> List[str]:
all_tokens = torch.zeros(
(tokens.shape[0], tokens.shape[1] + max_tokens_generated),
dtype=torch.long,
device=tokens.device,
)
all_tokens[:, : tokens.shape[1]] = tokens
for i in range(max_tokens_generated):
with model.hooks(fwd_hooks=fwd_hooks):
logits = model(all_tokens[:, : -max_tokens_generated + i])
next_tokens = logits[:, -1, :].argmax(
dim=-1
) # greedy sampling (temperature=0)
all_tokens[:, -max_tokens_generated + i] = next_tokens
return tokenizer.batch_decode(
all_tokens[:, tokens.shape[1] :], skip_special_tokens=True
)
def get_generations(
model: HookedTransformer,
tokenizer: AutoTokenizer,
instructions: List[str],
fwd_hooks=[],
max_tokens_generated: int = 64,
batch_size: int = 4,
) -> List[str]:
generations = []
for i in tqdm(range(0, len(instructions), batch_size)):
tokens = tokenize_instructions(
tokenizer, instructions=instructions[i : i + batch_size]
)
generation = _generate_with_hooks(
model,
tokenizer,
tokens,
max_tokens_generated=max_tokens_generated,
fwd_hooks=fwd_hooks,
)
generations.extend(generation)
return generations
# Inference-time intervention hook
def direction_ablation_hook(
activation: Float[Tensor, "... d_act"],
hook: HookPoint,
direction: Float[Tensor, "d_act"],
):
if activation.device != direction.device:
direction = direction.to(activation.device)
proj = (
einops.einsum(
activation, direction.view(-1, 1), "... d_act, d_act single -> ... single"
)
* direction
)
return activation - proj
# Testing baseline
N_INST_TEST = 4
baseline_generations = get_generations(
model, tokenizer, harmful_inst_test[:N_INST_TEST], fwd_hooks=[]
)
# Evaluating layers defined earlier (needs human evaluation to determine best layer for refusal inhibition)
EVAL_N = 25 # Evaluate how many of the top N potential directions
evals = []
for refusal_dir in tqdm(activation_scored[:EVAL_N]):
hook_fn = functools.partial(direction_ablation_hook, direction=refusal_dir)
fwd_hooks = [
(utils.get_act_name(act_name, layer), hook_fn)
for layer in list(range(model.cfg.n_layers))
for act_name in activation_layers
]
intervention_generations = get_generations(
model, tokenizer, harmful_inst_test[:N_INST_TEST], fwd_hooks=fwd_hooks
)
evals.append(intervention_generations)
# Print generations for human evaluation
blacklist = ["I cannot", "I can't"]
for i in range(N_INST_TEST):
print(f"\033[1mINSTRUCTION {i}: {harmful_inst_test[i]}")
print(f"\nBASELINE COMPLETION:\n{baseline_generations[i]}\033[0m")
for layer_candidate in range(EVAL_N):
if not any(word in evals[layer_candidate][i] for word in blacklist):
print(f"\n---\n\nLAYER CANDIDATE #{layer_candidate} INTERVENTION COMPLETION:")
print(evals[layer_candidate][i])
torch.cuda.empty_cache()
#model.to('cpu')
torch.cuda.empty_cache()
def get_orthogonalized_matrix(
matrix: Float[Tensor, "... d_model"], vec: Float[Tensor, "d_model"]
) -> Float[Tensor, "... d_model"]:
proj = (
einops.einsum(
matrix, vec.view(-1, 1), "... d_model, d_model single -> ... single"
)
* vec
)
return matrix - proj
# Select the layer with the highest potential refusal direction
LAYER_CANDIDATE = 18
refusal_dir = activation_scored[LAYER_CANDIDATE]
# Orthogonalize the model's weights
if refusal_dir.device != model.W_E.device:
refusal_dir = refusal_dir.to(model.W_E.device)
model.W_E.data = get_orthogonalized_matrix(model.W_E, refusal_dir)
for block in tqdm(model.blocks):
if refusal_dir.device != block.attn.W_O.device:
refusal_dir = refusal_dir.to(block.attn.W_O.device)
block.attn.W_O.data = get_orthogonalized_matrix(block.attn.W_O, refusal_dir)
block.mlp.W_out.data = get_orthogonalized_matrix(block.mlp.W_out, refusal_dir)
# Generate text with abliterated model
orthogonalized_generations = get_generations(
model, tokenizer, harmful_inst_test[:N_INST_TEST], fwd_hooks=[]
)
# Print generations
for i in range(N_INST_TEST):
if len(baseline_generations) > i:
print(f"INSTRUCTION {i}: {harmful_inst_test[i]}")
print(f"\033[92mBASELINE COMPLETION:\n{baseline_generations[i]}")
print(f"\033[91mINTERVENTION COMPLETION:\n{evals[LAYER_CANDIDATE][i]}")
print(f"\033[95mORTHOGONALIZED COMPLETION:\n{orthogonalized_generations[i]}\n")
model_state_dict = model.state_dict()
model_state_dict_keys = list(model_state_dict.keys())
# Display the first 100 keys for an overview
print(model_state_dict_keys)
# %%
# Convert model back to HF safetensors
hf_model = LlamaForCausalLM.from_pretrained(MODEL_TYPE, torch_dtype=torch.bfloat16)
lm_model = hf_model.model # Adjust this according to your model's internal architecture
# Load the state dictionary from your custom model
state_dict = model.state_dict()
# Assign the embedding weights
hf_model.get_input_embeddings().weight.data = state_dict["embed.W_E"].cpu()
# Iterate through the layers to adjust the weights based on your model's state_dict
for l in range(model.cfg.n_layers):
if hasattr(lm_model.layers[l], 'self_attn'):
# Set the weights for layers with self_attn
lm_model.layers[l].self_attn.o_proj.weight.data = torch.nn.Parameter(
einops.rearrange(
state_dict[f"blocks.{l}.attn.W_O"], "n h m->m (n h)", n=model.cfg.n_heads
).contiguous()
).cpu()
elif hasattr(lm_model.layers[l], 'cross_attn'):
# Set the weights for layers with cross_attn
lm_model.layers[l].cross_attn.o_proj.weight.data = torch.nn.Parameter(
einops.rearrange(
state_dict[f"blocks.{l}.attn.W_O"], "n h m->m (n h)", n=model.cfg.n_heads
).contiguous()
).cpu()
# Assign the feed-forward weights
lm_model.layers[l].mlp.down_proj.weight.data = torch.nn.Parameter(
torch.transpose(state_dict[f"blocks.{l}.mlp.W_out"], 0, 1).contiguous()
).cpu()
# Push it to the Hugging Face Hub
#hf_model.push_to_hub(NEW_MODEL_ID)
local_dir = "/home/myles/abliteration/mylesgoose/Llama-3.1-70B-Instruct-abliterated"
hf_model.save_pretrained(local_dir)
#hf_model.push_to_hub(NEW_MODEL_ID)
|
@mylesgoose I'll play with it to get it passing. Can you change the name of the branch you are merging from? If you change the name to something besides dev, or main, I can make my own edits on the branch in order to resolve these sorts of issues. When you are merging from a branch named main, I cannot push to the PR. |
I changed name to mllama and it closed the pull request. but I had given you write access to the code and I think you changed to dev also. sorry I just got the message and changed the branch. perhaps you changed in the mean time. 🤔 I think perhaps we can remove the few lines related to mllama so it can be merged. to give the 3.1 support now as that's working. |
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant.
Screenshots
Please attach before and after screenshots of the change if applicable.
Checklist: