From beeedaac457e324a273bc2c376d792c5c8d4ce74 Mon Sep 17 00:00:00 2001
From: llauraa23
Date: Wed, 24 Jan 2024 08:12:48 +0000
Subject: [PATCH] code cleanup for d2l demo. In SFT, make data collator,
 formatting function, and whether to disable evaluation configurable

---
 ...l.py => supervised_finetuning_demo_d2l.py} |  29 ++--
 pykoi/rlhf/config.py                          |  12 ++
 pykoi/rlhf/customize_data_collator.py         |  15 +-
 pykoi/rlhf/supervised_finetuning.py           | 141 +++++++++++-------
 4 files changed, 123 insertions(+), 74 deletions(-)
 rename example/rlhf/{supervised_finetuning_d2l.py => supervised_finetuning_demo_d2l.py} (65%)

diff --git a/example/rlhf/supervised_finetuning_d2l.py b/example/rlhf/supervised_finetuning_demo_d2l.py
similarity index 65%
rename from example/rlhf/supervised_finetuning_d2l.py
rename to example/rlhf/supervised_finetuning_demo_d2l.py
index 172d1f3..8493195 100644
--- a/example/rlhf/supervised_finetuning_d2l.py
+++ b/example/rlhf/supervised_finetuning_demo_d2l.py
@@ -1,8 +1,9 @@
 """Demo for the supervised fine-tuning.
 
-python -m example.rlhf.supervised_finetuning_demo
+python -m example.rlhf.supervised_finetuning_demo_d2l
 """
 
+from peft import LoraConfig
 from pykoi.chat import QuestionAnswerDatabase
 from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
                                      QA_CSV_HEADER_QUESTION,
@@ -26,27 +27,29 @@ print("My local database has {} samples in total".format(my_data_pd.shape[0]))
 
 # run supervised finetuning
-from peft import LoraConfig
 config = RLHFConfig(base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
                     dataset_type="local_csv",
                     dataset_name="data/chapter22_trnvalfromseed_data_processed.csv",
-                    train_test_split_ratio=0, # ratio for test set DH:TODO: COBINE TRAIN AND EVAL
+                    train_test_split_ratio=0,  # ratio for test set DH: TODO: COMBINE TRAIN AND EVAL
                     max_seq_length=896,
-                    per_device_eval_batch_size = 1,
-                    log_freq=20,
-                    # dh: NOTE: 1 EPOCH iterates the dataset once. So log freq 20 means iterating 20 entries when training batch size = 1.
+                    per_device_eval_batch_size=1,
+                    log_freq=20,
+                    # dh: NOTE: 1 EPOCH iterates the dataset once. So log freq 20 means iterating 20 entries when training batch size = 1.
                     # (i.e., log_freq = 0.12 epoch when the dataset has 166 entries).
                     save_freq=40000,
                     num_train_epochs=20,
-                    max_steps=-1, # if a positive number is given, it will override num_train_epochs
+                    max_steps=-1,  # if a positive number is given, it will override num_train_epochs
                     device_map="auto",
-                    lora_config_rl = LoraConfig(
-                        r=512,
-                        lora_alpha=1024,
-                        lora_dropout=0.05,
-                        target_modules=["q_proj","k_proj","v_proj","o_proj",], # "gate_proj","up_proj","down_proj",], #"lm_head",],
+                    lora_config_rl=LoraConfig(
+                        r=512,
+                        lora_alpha=1024,
+                        lora_dropout=0.05,
+                        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", ],  # "gate_proj","up_proj","down_proj",], #"lm_head",],
                         bias="none",
                         task_type="CAUSAL_LM"
-                    ),
+                    ),
+                    data_collator="DataCollatorForCompletionOnlyLM",
+                    no_evaluation=True,
+                    prepare_text="d2l",
                     )
 rlhf_step1_sft = SupervisedFinetuning(config)
 rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
diff --git a/pykoi/rlhf/config.py b/pykoi/rlhf/config.py
index 10d7184..ffd1ff1 100644
--- a/pykoi/rlhf/config.py
+++ b/pykoi/rlhf/config.py
@@ -184,6 +184,18 @@ class RLHFConfig:
         ),
         metadata={"help": "LoRA configuration."},
     )
+    data_collator: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the data collator to use for training."},
+    )
+    no_evaluation: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to disable evaluation during training."},
+    )
+    prepare_text: Optional[str] = field(
+        default="sample",
+        metadata={"help": "How to prepare the text for the model."},
+    )
 
     # Step 2 reward modeling parameters
     reward_model_path: Optional[str] = field(
diff --git a/pykoi/rlhf/customize_data_collator.py b/pykoi/rlhf/customize_data_collator.py
index 5cc8c1e..833269a 100644
--- a/pykoi/rlhf/customize_data_collator.py
+++ b/pykoi/rlhf/customize_data_collator.py
@@ -1,8 +1,11 @@
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Union
 from transformers import DataCollatorForLanguageModeling
 import numpy as np
+
+
 class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
-    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+    def torch_call(
+            self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         batch = super().torch_call(examples)
         # The prompt ends with the response key plus a newline.  We encode this and then try to find it in the
@@ -16,7 +19,8 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
 
         for i in range(len(examples)):
             response_token_ids_start_idx = None
-            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
+            for idx in np.where(
+                    batch["labels"][i] == response_token_ids[0])[0]:
                 response_token_ids_start_idx = idx
                 break
 
@@ -27,9 +31,10 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
 
             response_token_ids_end_idx = response_token_ids_start_idx + 1
 
-            # Make pytorch loss function ignore all tokens up through the end of the response key
+            # Make pytorch loss function ignore all tokens up through the end
+            # of the response key
             labels[i, :response_token_ids_end_idx] = -100
 
         batch["labels"] = labels
 
-        return batch
\ No newline at end of file
+        return batch
diff --git a/pykoi/rlhf/supervised_finetuning.py b/pykoi/rlhf/supervised_finetuning.py
index 27c8369..881831f 100644
--- a/pykoi/rlhf/supervised_finetuning.py
+++ b/pykoi/rlhf/supervised_finetuning.py
@@ -22,6 +22,7 @@ from pykoi.telemetry.telemetry import Telemetry
 from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM
 
+
 class SupervisedFinetuning:
     """
     A class representing the supervised finetuning trainer.
@@ -37,7 +38,10 @@ class SupervisedFinetuning:
         trainer (SFTTrainer): The trainer object used for training the model.
     """
 
-    def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> None:
+    def __init__(
+            self,
+            rlhf_config: RLHFConfig,
+            enable_telemetry: bool = True) -> None:
         """
         Initializes the SFTTrainer object.
 
@@ -47,17 +51,18 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
         """
         self._telemetry = Telemetry(enable_telemetry)
         self._rlhf_config = rlhf_config
-        self.tokenizer = AutoTokenizer.from_pretrained(rlhf_config.base_model_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            rlhf_config.base_model_path)
         # dh: add special tokens to tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
         END_KEY = "### End"
         INSTRUCTION_KEY = "### Instruction:"
         RESPONSE_KEY = "### Response:"
         RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
-        self.tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
+        self.tokenizer.add_special_tokens(
+            {"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
         self.num_proc = (
-            self._rlhf_config.num_workers if not self._rlhf_config.streaming else None
-        )
+            self._rlhf_config.num_workers if not self._rlhf_config.streaming else None)
         self.dataset = self.create_datasets(self.tokenizer, self._rlhf_config)
         self.torch_dtype = torch.bfloat16 if self._rlhf_config.bf16 else torch.float16
         # self.torch_dtype = torch.bfloat16 if bf16 else (torch.float16 if fp16 else torch.float32)
@@ -77,8 +82,7 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
             gradient_accumulation_steps=self._rlhf_config.gradient_accumulation_steps,
             gradient_checkpointing=self._rlhf_config.gradient_checkpointing,
             gradient_checkpointing_kwargs={
-                "use_reentrant": self._rlhf_config.gradient_checkpointing_use_reentrant
-            },
+                "use_reentrant": self._rlhf_config.gradient_checkpointing_use_reentrant},
             fp16=self._rlhf_config.fp16,
             bf16=self._rlhf_config.bf16,
             weight_decay=self._rlhf_config.weight_decay,
@@ -93,18 +97,20 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
         )
 
         # resize the token embeddings to include the added special tokens
         self.model.resize_token_embeddings(len(self.tokenizer))
-
-        # dh: try the customized data collator that only predicts the answer part
-        data_collator = DataCollatorForCompletionOnlyLM(
-            tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
-        )
+        data_collator = None
+        if self._rlhf_config.data_collator == "DataCollatorForCompletionOnlyLM":
+            # dh: try the customized data collator that only predicts the
+            # answer part
+            data_collator = DataCollatorForCompletionOnlyLM(
+                tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8)
         self.trainer = SFTTrainer(
             model=self.model,
             args=self.training_args,
             train_dataset=self.dataset["train"],
             eval_dataset=self.dataset["eval"],
-            peft_config=self._rlhf_config.lora_config_rl, ## TODO: DH: LoraConfig MAY BE IGNORED IF USING FROM_PRETRAINED
+            peft_config=self._rlhf_config.lora_config_rl,
+            # TODO: DH: LoraConfig MAY BE IGNORED IF USING FROM_PRETRAINED
             packing=True,
             data_collator=data_collator,
             dataset_text_field="text",
@@ -163,8 +169,9 @@ def save(self, output_path=None):
 
     def train_and_save(self, output_path=None):
         start_event = SFTStartEvent(
-            start_time=time.time(), date_time=datetime.utcfromtimestamp(time.time())
-        )
+            start_time=time.time(),
+            date_time=datetime.utcfromtimestamp(
+                time.time()))
         self._telemetry.capture(start_event)
         self.trainer.train()
         self.save(output_path)
@@ -180,10 +187,8 @@ def prepare_sample_text(self, example):
         """Prepare the text from a sample of the dataset."""
         text = (
             f"Question: {example[self._rlhf_config.question_title]}\n\n "
-            f" Answer: {example[self._rlhf_config.answer_title]}"
-        )
+            f" Answer: {example[self._rlhf_config.answer_title]}")
         return text
-
     def prepare_d2l_text(self, example):
         """Prepare the text from a sample of the d2l dataset."""
@@ -198,7 +203,8 @@ def prepare_d2l_text(self, example):
         DEFAULT_SEED = 42
 
         # This is a training prompt that does not contain an input string. The instruction by itself has enough information
-        # to respond. For example, the instruction might ask for the year a historic figure was born.
+        # to respond. For example, the instruction might ask for the year a
+        # historic figure was born.
         PROMPT_NO_INPUT_FORMAT = """{intro}
 {instruction_key}
 {instruction}
 {response_key}
 {response}
 {end_key}""".format(
             intro=INTRO_BLURB,
             instruction_key=INSTRUCTION_KEY,
             instruction="{instruction}",
             response_key=RESPONSE_KEY,
             response="{response}",
             end_key=END_KEY,
         )
 
@@ -214,7 +220,8 @@ def prepare_d2l_text(self, example):
         # This is a training prompt that contains an input string that serves as context for the instruction. For example,
-        # the input might be a passage from Wikipedia and the intruction is to extract some information from it.
+        # the input might be a passage from Wikipedia and the instruction is to
+        # extract some information from it.
PROMPT_WITH_INPUT_FORMAT = """{intro} {instruction_key} {instruction} @@ -232,14 +239,17 @@ def prepare_d2l_text(self, example): response="{response}", end_key=END_KEY, ) - + context = example.get("context") if context: - text = PROMPT_WITH_INPUT_FORMAT.format(instruction=example["instruction"], response=example["response"], input=context) + text = PROMPT_WITH_INPUT_FORMAT.format( + instruction=example["instruction"], + response=example["response"], + input=context) else: - text = PROMPT_NO_INPUT_FORMAT.format(instruction=example["instruction"], response=example["instruction"]) - - + text = PROMPT_NO_INPUT_FORMAT.format( + instruction=example["instruction"], + response=example["instruction"]) return text @@ -258,13 +268,16 @@ def create_datasets(self, tokenizer, args): ) dataset = Dataset.from_dict(my_data_pd) elif args.dataset_type == "local_csv": - ## this way will load 1660 enetries + # this way will load 1660 enetries # dataset = load_dataset("csv", data_files=args.dataset_name) # dataset = dataset[args.split] # Convert DatasetDict to Dataset # this way will load 166 entries - dataset = load_dataset("csv", data_files=args.dataset_name, split='train[:10%]') + dataset = load_dataset( + "csv", + data_files=args.dataset_name, + split='train[:10%]') elif args.dataset_type == "huggingface": dataset = load_dataset( @@ -281,34 +294,50 @@ def create_datasets(self, tokenizer, args): "No (supported) data files or dataset script found" f" {args.dataset_type}" ) - - # dh: temp change. No test set - # dataset = dataset.train_test_split( - # test_size=args.train_test_split_ratio, seed=args.seed - # ) - print( - f"Size of the train set: {len(dataset)}. " - #f"Size of the train set: {len(dataset['train'])}. " - #f" Size of the validation set: {len(dataset['test'])}" - ) - train_dataset = ConstantLengthDataset( - tokenizer, - dataset, - #dataset["train"], #dh: temp change. No test set - formatting_func=self.prepare_d2l_text, - infinite=True, - seq_length=args.max_seq_length, - # chars_per_token=chars_per_token, - ) - # temp change: no test set - # eval_dataset = ConstantLengthDataset( - # tokenizer, - # dataset["test"], - # formatting_func=self.prepare_d2l_text, - # infinite=False, - # seq_length=args.max_seq_length, - # # chars_per_token=chars_per_token, - # ) - eval_dataset = None + if args.prepare_text == "d2l": + self.prepare_text = self.prepare_d2l_text + else: + self.prepare_text = self.prepare_sample_text + # No test set during training + if args.no_evaluation: + print( + f"Size of the train set: {len(dataset)}. " + ) + + train_dataset = ConstantLengthDataset( + tokenizer, + dataset, + formatting_func=self.prepare_text, + infinite=True, + seq_length=args.max_seq_length, + # chars_per_token=chars_per_token, + ) + eval_dataset = None + else: + dataset = dataset.train_test_split( + test_size=args.train_test_split_ratio, seed=args.seed + ) + print( + f"Size of the train set: {len(dataset['train'])}. " + f" Size of the validation set: {len(dataset['test'])}") + + train_dataset = ConstantLengthDataset( + tokenizer, + dataset["train"], + formatting_func=self.prepare_text, + infinite=True, + seq_length=args.max_seq_length, + # chars_per_token=chars_per_token, + ) + + eval_dataset = ConstantLengthDataset( + tokenizer, + dataset["test"], + formatting_func=self.prepare_text, + infinite=False, + seq_length=args.max_seq_length, + # chars_per_token=chars_per_token, + ) + return {"train": train_dataset, "eval": eval_dataset}