support different datasets
sangkeun00 committed Apr 27, 2024
1 parent 97e5f58 commit 02480a0
Showing 4 changed files with 121 additions and 35 deletions.
examples/language_modeling/compute_influence.py (23 additions, 9 deletions)

@@ -30,23 +30,35 @@ def main():
         type=str,
         default="/data/tir/projects/tir3/users/sangkeuc/gpt/results",
     )
-    parser.add_argument("--model_name", type=str, default="gpt2")
-    parser.add_argument("--hessian", type=str, default="kfac")
+    parser.add_argument("--model_name", type=str, default="gpt2-xl")
+    parser.add_argument("--data_path", type=str, default="wikitext")
+    parser.add_argument("--data_name", type=str, default=None)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--hessian", type=str, default="raw")
     parser.add_argument("--lora", type=str, default="random")
-    parser.add_argument("--split", type=str, default="valid")
+    parser.add_argument("--split", type=str, default="train")
     parser.add_argument("--mlp_only", action="store_true")
-    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--layerwise", action="store_true")
     parser.add_argument("--damping", type=float, default=1e-5)
     args = parser.parse_args()
 
     set_seed(0)
     accelerator = Accelerator()
+    influence_groups = None
+    if args.layerwise:
+        layer_id = "h" if args.model_name == "gpt2-xl" else "layers"
+        layer_num = 48 if args.model_name == "gpt2-xl" else 32
+        influence_groups = [f".{layer_id}.{i}." for i in range(layer_num)]
 
     # prepare model & data loader
     model = get_model(model_name=args.model_name, cache_dir=args.cache_dir)
-    tokenizer = get_tokenizer(model_name=args.model_name, cache_dir=args.cache_dir)
+    tokenizer = get_tokenizer(
+        model_name=args.model_name, cache_dir=args.cache_dir, add_padding_token=True
+    )
     data_loader = get_loader(
         model_name=args.model_name,
+        data_path=args.data_path,
+        data_name=args.data_name,
         tokenizer=tokenizer,
         batch_size=args.batch_size,
         cache_dir=args.cache_dir,
@@ -91,9 +103,11 @@ def main():
         test_log = logix.get_log()
         merged_test_logs.append(copy.deepcopy(test_log))
 
-        if idx == 7 or idx == len(data_loader) - 1:
+        if idx == 12 or idx == len(data_loader) - 1:
             merged_test_log = merge_logs(merged_test_logs)
-            result = run.influence.compute_influence_all(merged_test_log, log_loader)
+            result = run.influence.compute_influence_all(
+                merged_test_log, log_loader, influence_groups=influence_groups
+            )
             merged_test_logs = []
             break

@@ -104,8 +118,8 @@ def main():
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     torch.save(result["influence"], os.path.join(save_dir, "scores.pt"))
-    torch.save(result["src_ids"], os.path.join(save_dir, "train_ids.pt"))
-    torch.save(result["tgt_ids"], os.path.join(save_dir, "test_ids.pt"))
+    torch.save(result["src_ids"], os.path.join(save_dir, "test_ids.pt"))
+    torch.save(result["tgt_ids"], os.path.join(save_dir, "train_ids.pt"))
 
 
 if __name__ == "__main__":
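A note on the new --layerwise option: the substring patterns built above (".h.0.", ".h.1.", ...) group influence scores per transformer block by matching them against module names. A minimal, self-contained sketch of that matching idea, using hypothetical module paths in the style of GPT-2's naming scheme (LogIX's internal grouping logic may differ):

    # Hypothetical module paths; LogIX matches against its own logged names.
    module_names = [
        "transformer.h.0.attn.c_attn",
        "transformer.h.0.mlp.c_fc",
        "transformer.h.1.attn.c_attn",
    ]
    influence_groups = [f".h.{i}." for i in range(2)]
    # Bucket each module under the layer pattern its name contains.
    by_layer = {g: [m for m in module_names if g in m] for g in influence_groups}
    print(by_layer)
    # {'.h.0.': ['transformer.h.0.attn.c_attn', 'transformer.h.0.mlp.c_fc'],
    #  '.h.1.': ['transformer.h.1.attn.c_attn']}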
examples/language_modeling/extract_log.py (6 additions, 1 deletion)

@@ -26,12 +26,15 @@ def main():
default="/data/tir/projects/tir3/users/sangkeuc/huggingface",
)
parser.add_argument("--model_name", type=str, default="gpt2")
parser.add_argument("--data_path", type=str, default="wikitext")
parser.add_argument("--data_name", type=str, default=None)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--hessian", type=str, default="kfac")
parser.add_argument("--hessian", type=str, default="raw")
parser.add_argument("--lora", type=str, default="random")
parser.add_argument("--save", type=str, default="grad")
parser.add_argument("--mlp_only", action="store_true")
args = parser.parse_args()
print(args)

set_seed(0)
accelerator = Accelerator()
@@ -41,6 +44,8 @@
     tokenizer = get_tokenizer(model_name=args.model_name, cache_dir=args.cache_dir)
     data_loader = get_loader(
         model_name=args.model_name,
+        data_path=args.data_path,
+        data_name=args.data_name,
         tokenizer=tokenizer,
         batch_size=args.batch_size,
         cache_dir=args.cache_dir,
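For context, the new --data_path/--data_name pair is forwarded through get_loader to Hugging Face's datasets.load_dataset. A hedged sketch of the equivalent direct call (the config name below is illustrative, since --data_name defaults to None and wikitext requires an explicit config):

    from datasets import load_dataset

    raw = load_dataset(
        path="wikitext",             # maps to --data_path
        name="wikitext-103-raw-v1",  # maps to --data_name
        cache_dir="./cache",
        num_proc=4,
    )
    print(raw["train"][0]["text"][:80])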
examples/language_modeling/generate.py (new file, 39 additions)

@@ -0,0 +1,39 @@
+import transformers
+import torch
+
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+pipeline = transformers.pipeline(
+    "text-generation",
+    model=model_id,
+    model_kwargs={"torch_dtype": torch.bfloat16, "cache_dir": "./cache"},
+    device="cuda",
+)
+
+messages = [
+    {
+        "role": "system",
+        "content": "You are a pirate chatbot who always responds in pirate speak!",
+    },
+    {"role": "user", "content": "Give three tips for staying healthy."},
+]
+
+prompt = pipeline.tokenizer.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)
+
+terminators = [
+    pipeline.tokenizer.eos_token_id,
+    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+]
+
+outputs = pipeline(
+    prompt,
+    max_new_tokens=256,
+    eos_token_id=terminators,
+    do_sample=True,
+    temperature=0.6,
+    top_p=0.9,
+)
+print(outputs[0]["generated_text"])
+# print(outputs[0]["generated_text"][len(prompt):])
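The commented-out final line hints at the usual cleanup step: with return_full_text left at its default, the text-generation pipeline returns the prompt followed by the completion, so slicing off len(prompt) isolates the model's reply. A toy, model-free illustration of that idiom:

    # generated_text == prompt + completion under the default pipeline settings,
    # so slicing by the prompt length keeps only the reply.
    prompt = "<|user|> Give three tips for staying healthy. <|assistant|>"
    generated_text = prompt + " 1) Sleep well. 2) Eat greens. 3) Move daily."
    print(generated_text[len(prompt):].strip())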
examples/language_modeling/utils.py (53 additions, 25 deletions)

@@ -74,15 +74,17 @@ def get_model(model_name, cache_dir) -> nn.Module:
     return LanguageModel(model_name, cache_dir)
 
 
-def get_tokenizer(model_name, cache_dir) -> PreTrainedTokenizer:
+def get_tokenizer(
+    model_name, cache_dir, add_padding_token=False
+) -> PreTrainedTokenizer:
     tokenizer = AutoTokenizer.from_pretrained(
         model_name,
         use_fast=True,
         trust_remote_code=("gpt2" in model_name),
         cache_dir=cache_dir,
     )
 
-    if tokenizer.pad_token is None:
+    if tokenizer.pad_token is None and add_padding_token:
         print("No pad token found. Setting `<pad>` as a pad token.")
         tokenizer.pad_token = "<pad>"
         if "<pad>" not in tokenizer.get_vocab():
@@ -95,34 +97,53 @@ def get_tokenizer(model_name, cache_dir) -> PreTrainedTokenizer:
 def get_dataset(
     model_name: str,
     tokenizer: PreTrainedTokenizer,
+    data_path: str,
+    data_name: Optional[str] = None,
     split: str = "train",
+    sample_ratio: float = 0.005,
     cache_dir: str = None,
 ) -> torch.utils.data.DataLoader:
-    assert split in ["train", "valid", "generated", "external"]
 
-    model_name_split = model_name.split("/")[-1]
-    data_name = "wiki" if split in ["train", "valid"] else split
-    split_key = "validation" if split == "valid" else "train"
-    if os.path.exists(os.path.join(cache_dir, f"{model_name_split}_{data_name}.pt")):
+    model_name_strip = model_name.split("/")[-1]
+    save_data_name = data_path if data_name is None else data_name
+    save_data_name = save_data_name.split("/")[-1]
+    if os.path.exists(
+        os.path.join(cache_dir, f"{model_name_strip}_{save_data_name}.pt")
+    ):
         print("[*] Loading from cached data...")
         lm_datasets = load_from_disk(
-            os.path.join(cache_dir, f"{model_name_split}_{data_name}.pt")
+            os.path.join(cache_dir, f"{model_name_strip}_{save_data_name}.pt")
         )
-        return lm_datasets[split_key]
+        return lm_datasets[split]
 
     # Prepare raw dataset
-    if split in ["train", "valid"]:
-        data_path = "wikitext"
-        data_kwargs = {"name": "wikitext-103-raw-v1"}
-    elif split in ["external"]:
-        data_path = "json"
-        data_kwargs = {"data_files": "./custom_data/external/data.json"}
+    if data_path == "external":
+        data_kwargs = {
+            "path": "json",
+            "data_files": "./custom_data/external/data.json",
+            "cache_dir": cache_dir,
+            "num_proc": 4,
+        }
+    elif data_path == "generated":
+        data_kwargs = {
+            "path": "json",
+            "data_files": f"./custom_data/generated/{model_name_strip}/data.json",
+            "cache_dir": cache_dir,
+            "num_proc": 4,
+        }
     else:
-        data_path = "json"
         data_kwargs = {
-            "data_files": f"./custom_data/generated/{model_name_split}/data.json"
+            "path": data_path,
+            "name": data_name,
+            "cache_dir": cache_dir,
+            "num_proc": 4,
         }
-    raw_datasets = load_dataset(data_path, **data_kwargs)
+    raw_datasets = load_dataset(**data_kwargs)
+
+    if sample_ratio is not None:
+        sampled_train = raw_datasets["train"].train_test_split(
+            test_size=0.005, shuffle=True, seed=42
+        )
+        raw_datasets["train"] = sampled_train["test"]
 
     # Tokenize dataset
     column_names = raw_datasets["train"].column_names
@@ -134,14 +155,14 @@ def tokenize_function(examples):
     tokenized_datasets = raw_datasets.map(
         tokenize_function,
         batched=True,
-        num_proc=None,
+        num_proc=4,
         remove_columns=column_names,
         load_from_cache_file=True,
         desc="Running tokenizer on dataset",
     )
 
     # Group text
-    if split in ["train", "valid"]:
+    if data_path not in ["generated", "external"]:
         block_size = 512
 
         def group_texts(examples):
@@ -171,28 +192,35 @@ def group_texts(examples):
         lm_datasets = tokenized_datasets.map(
             group_texts,
             batched=True,
-            num_proc=None,
+            num_proc=4,
             load_from_cache_file=True,
             desc=f"Grouping texts in chunks of {block_size}",
         )
 
     print("[*] Saving data to disk...")
     lm_datasets.save_to_disk(
-        os.path.join(cache_dir, f"{model_name_split}_{data_name}.pt")
+        os.path.join(cache_dir, f"{model_name_strip}_{save_data_name}.pt")
     )
 
-    return lm_datasets[split_key]
+    return lm_datasets[split]


 def get_loader(
     model_name: str,
     tokenizer: PreTrainedTokenizer,
     batch_size: int,
+    data_path: str,
+    data_name: Optional[str] = None,
     split: str = "train",
     cache_dir: str = None,
 ) -> torch.utils.data.DataLoader:
     dataset = get_dataset(
-        model_name=model_name, tokenizer=tokenizer, split=split, cache_dir=cache_dir
+        model_name=model_name,
+        tokenizer=tokenizer,
+        data_path=data_path,
+        data_name=data_name,
+        split=split,
+        cache_dir=cache_dir,
     )
     dataloader = torch.utils.data.DataLoader(
         dataset, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator
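Putting the updated helpers together, a usage sketch under stated assumptions (argument values are illustrative, and the helpers are assumed importable from examples/language_modeling/utils.py):

    from utils import get_tokenizer, get_loader

    tokenizer = get_tokenizer("gpt2", cache_dir="./cache", add_padding_token=True)
    loader = get_loader(
        model_name="gpt2",
        tokenizer=tokenizer,
        batch_size=1,
        data_path="wikitext",
        data_name="wikitext-103-raw-v1",
        split="train",
        cache_dir="./cache",
    )
    batch = next(iter(loader))       # dict of tensors from default_data_collator
    print(batch["input_ids"].shape)  # e.g. torch.Size([1, 512])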
