From 02480a09cc1cb63fa120fed4c1ba9f5f0662de7d Mon Sep 17 00:00:00 2001
From: sangkeun00
Date: Sat, 27 Apr 2024 00:00:10 -0400
Subject: [PATCH] support different datasets

---
 .../language_modeling/compute_influence.py | 32 +++++---
 examples/language_modeling/extract_log.py  |  7 +-
 examples/language_modeling/generate.py     | 39 ++++++++++
 examples/language_modeling/utils.py        | 78 +++++++++++++------
 4 files changed, 121 insertions(+), 35 deletions(-)
 create mode 100755 examples/language_modeling/generate.py

diff --git a/examples/language_modeling/compute_influence.py b/examples/language_modeling/compute_influence.py
index 781ba1bc..b36efb20 100644
--- a/examples/language_modeling/compute_influence.py
+++ b/examples/language_modeling/compute_influence.py
@@ -30,23 +30,35 @@ def main():
         type=str,
         default="/data/tir/projects/tir3/users/sangkeuc/gpt/results",
     )
-    parser.add_argument("--model_name", type=str, default="gpt2")
-    parser.add_argument("--hessian", type=str, default="kfac")
+    parser.add_argument("--model_name", type=str, default="gpt2-xl")
+    parser.add_argument("--data_path", type=str, default="wikitext")
+    parser.add_argument("--data_name", type=str, default=None)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--hessian", type=str, default="raw")
     parser.add_argument("--lora", type=str, default="random")
-    parser.add_argument("--split", type=str, default="valid")
+    parser.add_argument("--split", type=str, default="train")
     parser.add_argument("--mlp_only", action="store_true")
-    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--layerwise", action="store_true")
     parser.add_argument("--damping", type=float, default=1e-5)
     args = parser.parse_args()
 
     set_seed(0)
     accelerator = Accelerator()
 
+    influence_groups = None
+    if args.layerwise:
+        layer_id = "h" if args.model_name == "gpt2-xl" else "layers"
+        layer_num = 48 if args.model_name == "gpt2-xl" else 32
+        influence_groups = [f".{layer_id}.{i}." for i in range(layer_num)]
     # prepare model & data loader
     model = get_model(model_name=args.model_name, cache_dir=args.cache_dir)
-    tokenizer = get_tokenizer(model_name=args.model_name, cache_dir=args.cache_dir)
+    tokenizer = get_tokenizer(
+        model_name=args.model_name, cache_dir=args.cache_dir, add_padding_token=True
+    )
     data_loader = get_loader(
         model_name=args.model_name,
+        data_path=args.data_path,
+        data_name=args.data_name,
         tokenizer=tokenizer,
         batch_size=args.batch_size,
         cache_dir=args.cache_dir,
@@ -91,9 +103,11 @@ def main():
         test_log = logix.get_log()
         merged_test_logs.append(copy.deepcopy(test_log))
 
-        if idx == 7 or idx == len(data_loader) - 1:
+        if idx == 12 or idx == len(data_loader) - 1:
             merged_test_log = merge_logs(merged_test_logs)
-            result = run.influence.compute_influence_all(merged_test_log, log_loader)
+            result = run.influence.compute_influence_all(
+                merged_test_log, log_loader, influence_groups=influence_groups
+            )
             merged_test_logs = []
             break
 
@@ -104,8 +118,8 @@ def main():
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     torch.save(result["influence"], os.path.join(save_dir, "scores.pt"))
-    torch.save(result["src_ids"], os.path.join(save_dir, "train_ids.pt"))
-    torch.save(result["tgt_ids"], os.path.join(save_dir, "test_ids.pt"))
+    torch.save(result["src_ids"], os.path.join(save_dir, "test_ids.pt"))
+    torch.save(result["tgt_ids"], os.path.join(save_dir, "train_ids.pt"))
 
 
 if __name__ == "__main__":
diff --git a/examples/language_modeling/extract_log.py b/examples/language_modeling/extract_log.py
index 95bdecd6..8f51be40 100644
--- a/examples/language_modeling/extract_log.py
+++ b/examples/language_modeling/extract_log.py
@@ -26,12 +26,15 @@ def main():
         default="/data/tir/projects/tir3/users/sangkeuc/huggingface",
     )
     parser.add_argument("--model_name", type=str, default="gpt2")
+    parser.add_argument("--data_path", type=str, default="wikitext")
+    parser.add_argument("--data_name", type=str, default=None)
     parser.add_argument("--batch_size", type=int, default=8)
-    parser.add_argument("--hessian", type=str, default="kfac")
+    parser.add_argument("--hessian", type=str, default="raw")
     parser.add_argument("--lora", type=str, default="random")
     parser.add_argument("--save", type=str, default="grad")
     parser.add_argument("--mlp_only", action="store_true")
     args = parser.parse_args()
+    print(args)
 
     set_seed(0)
     accelerator = Accelerator()
@@ -41,6 +44,8 @@ def main():
     tokenizer = get_tokenizer(model_name=args.model_name, cache_dir=args.cache_dir)
     data_loader = get_loader(
         model_name=args.model_name,
+        data_path=args.data_path,
+        data_name=args.data_name,
         tokenizer=tokenizer,
         batch_size=args.batch_size,
         cache_dir=args.cache_dir,
diff --git a/examples/language_modeling/generate.py b/examples/language_modeling/generate.py
new file mode 100755
index 00000000..160bd250
--- /dev/null
+++ b/examples/language_modeling/generate.py
@@ -0,0 +1,39 @@
+import transformers
+import torch
+
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+pipeline = transformers.pipeline(
+    "text-generation",
+    model=model_id,
+    model_kwargs={"torch_dtype": torch.bfloat16, "cache_dir": "./cache"},
+    device="cuda",
+)
+
+messages = [
+    {
+        "role": "system",
+        "content": "You are a pirate chatbot who always responds in pirate speak!",
+    },
+    {"role": "user", "content": "Give three tips for staying healthy."},
+]
+
+prompt = pipeline.tokenizer.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)
+
+terminators = [
+    pipeline.tokenizer.eos_token_id,
+    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+]
+
+outputs = pipeline(
+    prompt,
+    max_new_tokens=256,
+    eos_token_id=terminators,
+    do_sample=True,
+    temperature=0.6,
+    top_p=0.9,
+)
+print(outputs[0]["generated_text"])
+# print(outputs[0]["generated_text"][len(prompt):])
diff --git a/examples/language_modeling/utils.py b/examples/language_modeling/utils.py
index b036eba8..ace99737 100644
--- a/examples/language_modeling/utils.py
+++ b/examples/language_modeling/utils.py
@@ -74,7 +74,9 @@ def get_model(model_name, cache_dir) -> nn.Module:
     return LanguageModel(model_name, cache_dir)
 
 
-def get_tokenizer(model_name, cache_dir) -> PreTrainedTokenizer:
+def get_tokenizer(
+    model_name, cache_dir, add_padding_token=False
+) -> PreTrainedTokenizer:
     tokenizer = AutoTokenizer.from_pretrained(
         model_name,
         use_fast=True,
@@ -82,7 +84,7 @@ def get_tokenizer(model_name, cache_dir) -> PreTrainedTokenizer:
         cache_dir=cache_dir,
     )
 
-    if tokenizer.pad_token is None:
+    if tokenizer.pad_token is None and add_padding_token:
         print("No pad token found. Setting `<pad>` as a pad token.")
         tokenizer.pad_token = "<pad>"
         if "<pad>" not in tokenizer.get_vocab():
@@ -95,34 +97,53 @@ def get_dataset(
     model_name: str,
     tokenizer: PreTrainedTokenizer,
+    data_path: str,
+    data_name: Optional[str] = None,
     split: str = "train",
+    sample_ratio: float = 0.005,
     cache_dir: str = None,
 ) -> torch.utils.data.DataLoader:
-    assert split in ["train", "valid", "generated", "external"]
-
-    model_name_split = model_name.split("/")[-1]
-    data_name = "wiki" if split in ["train", "valid"] else split
-    split_key = "validation" if split == "valid" else "train"
-    if os.path.exists(os.path.join(cache_dir, f"{model_name_split}_{data_name}.pt")):
+    model_name_strip = model_name.split("/")[-1]
+    save_data_name = data_path if data_name is None else data_name
+    save_data_name = save_data_name.split("/")[-1]
+    if os.path.exists(
+        os.path.join(cache_dir, f"{model_name_strip}_{save_data_name}.pt")
+    ):
         print("[*] Loading from cached data...")
         lm_datasets = load_from_disk(
-            os.path.join(cache_dir, f"{model_name_split}_{data_name}.pt")
+            os.path.join(cache_dir, f"{model_name_strip}_{save_data_name}.pt")
         )
-        return lm_datasets[split_key]
+        return lm_datasets[split]
 
     # Prepare raw dataset
-    if split in ["train", "valid"]:
-        data_path = "wikitext"
-        data_kwargs = {"name": "wikitext-103-raw-v1"}
-    elif split in ["external"]:
-        data_path = "json"
-        data_kwargs = {"data_files": "./custom_data/external/data.json"}
+    if data_path == "external":
+        data_kwargs = {
+            "path": "json",
+            "data_files": "./custom_data/external/data.json",
+            "cache_dir": cache_dir,
+            "num_proc": 4,
+        }
+    elif data_path == "generated":
+        data_kwargs = {
+            "path": "json",
+            "data_files": f"./custom_data/generated/{model_name_strip}/data.json",
+            "cache_dir": cache_dir,
+            "num_proc": 4,
+        }
     else:
-        data_path = "json"
         data_kwargs = {
-            "data_files": f"./custom_data/generated/{model_name_split}/data.json"
+            "path": data_path,
+            "name": data_name,
+            "cache_dir": cache_dir,
+            "num_proc": 4,
         }
-    raw_datasets = load_dataset(data_path, **data_kwargs)
+    raw_datasets = load_dataset(**data_kwargs)
+
+    if sample_ratio is not None:
+        sampled_train = raw_datasets["train"].train_test_split(
+            test_size=sample_ratio, shuffle=True, seed=42
+        )
+        raw_datasets["train"] = sampled_train["test"]
 
     # Tokenize dataset
     column_names = raw_datasets["train"].column_names
@@ -134,14 +155,14 @@ def tokenize_function(examples):
     tokenized_datasets = raw_datasets.map(
         tokenize_function,
         batched=True,
-        num_proc=None,
+        num_proc=4,
         remove_columns=column_names,
         load_from_cache_file=True,
         desc="Running tokenizer on dataset",
     )
 
     # Group text
-    if split in ["train", "valid"]:
+    if data_path not in ["generated", "external"]:
         block_size = 512
 
         def group_texts(examples):
@@ -171,28 +192,35 @@ def group_texts(examples):
         lm_datasets = tokenized_datasets.map(
             group_texts,
             batched=True,
-            num_proc=None,
+            num_proc=4,
             load_from_cache_file=True,
             desc=f"Grouping texts in chunks of {block_size}",
         )
 
     print("[*] Saving data to disk...")
     lm_datasets.save_to_disk(
-        os.path.join(cache_dir, f"{model_name_split}_{data_name}.pt")
+        os.path.join(cache_dir, f"{model_name_strip}_{save_data_name}.pt")
     )
 
-    return lm_datasets[split_key]
+    return lm_datasets[split]
 
 
 def get_loader(
     model_name: str,
     tokenizer: PreTrainedTokenizer,
     batch_size: int,
+    data_path: str,
+    data_name: Optional[str] = None,
     split: str = "train",
     cache_dir: str = None,
 ) -> torch.utils.data.DataLoader:
     dataset = get_dataset(
-        model_name=model_name, tokenizer=tokenizer, split=split, cache_dir=cache_dir
+        model_name=model_name,
+        tokenizer=tokenizer,
+        data_path=data_path,
+        data_name=data_name,
+        split=split,
+        cache_dir=cache_dir,
     )
     dataloader = torch.utils.data.DataLoader(
         dataset, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator