From 1ab4f325279b4fb438c187988decda3752063187 Mon Sep 17 00:00:00 2001 From: llauraa23 Date: Tue, 9 Jan 2024 23:48:17 +0000 Subject: [PATCH] support supervised fine tuning on d2l. execute with "python -m example.rlhf.supervised_finetuning_d2l" --- example/rlhf/supervised_finetuning_d2l.py | 45 ++++++++++++ pykoi/rlhf/config.py | 1 + pykoi/rlhf/customize_data_collator.py | 35 +++++++++ pykoi/rlhf/supervised_finetuning.py | 87 +++++++++++++++++++++-- 4 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 example/rlhf/supervised_finetuning_d2l.py create mode 100644 pykoi/rlhf/customize_data_collator.py diff --git a/example/rlhf/supervised_finetuning_d2l.py b/example/rlhf/supervised_finetuning_d2l.py new file mode 100644 index 0000000..0e50a85 --- /dev/null +++ b/example/rlhf/supervised_finetuning_d2l.py @@ -0,0 +1,45 @@ +"""Demo for the supervised fine tuning. + +python -m example.rlhf.supervised_finetuning_demo +""" + +from pykoi.chat import QuestionAnswerDatabase +from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID, + QA_CSV_HEADER_QUESTION, + QA_CSV_HEADER_VOTE_STATUS) +from pykoi.rlhf import RLHFConfig, SupervisedFinetuning + +# get data from local database +qa_database = QuestionAnswerDatabase() +my_data_pd = qa_database.retrieve_all_question_answers_as_pandas() +my_data_pd = my_data_pd[ + [ + QA_CSV_HEADER_ID, + QA_CSV_HEADER_QUESTION, + QA_CSV_HEADER_ANSWER, + QA_CSV_HEADER_VOTE_STATUS, + ] +] + +# analyze the data +print(my_data_pd) +print("My local database has {} samples in total".format(my_data_pd.shape[0])) + +# run supervised finetuning +from peft import LoraConfig +config = RLHFConfig(base_model_path="mistralai/Mistral-7B-Instruct-v0.1", + dataset_type="local_csv", dataset_name="data/chapter22_trnvalfromseed_data_processed.csv", + train_test_split_ratio=0.1, + max_seq_length=896, + per_device_eval_batch_size = 1, + lora_config_rl = LoraConfig( + r=512, + lora_alpha=1024, + lora_dropout=0.05, + target_modules=["q_proj","k_proj","v_proj","o_proj",], # "gate_proj","up_proj","down_proj",], #"lm_head",], + bias="none", + task_type="CAUSAL_LM" + ), + ) +rlhf_step1_sft = SupervisedFinetuning(config) +rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft") diff --git a/pykoi/rlhf/config.py b/pykoi/rlhf/config.py index e7721f1..c34d68e 100644 --- a/pykoi/rlhf/config.py +++ b/pykoi/rlhf/config.py @@ -5,6 +5,7 @@ from accelerate import Accelerator from peft import LoraConfig, TaskType +# TODO: DH: num_train_epochs=20, @dataclass diff --git a/pykoi/rlhf/customize_data_collator.py b/pykoi/rlhf/customize_data_collator.py new file mode 100644 index 0000000..5cc8c1e --- /dev/null +++ b/pykoi/rlhf/customize_data_collator.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, List, Tuple, Union +from transformers import DataCollatorForLanguageModeling +import numpy as np +class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): + def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + batch = super().torch_call(examples) + + # The prompt ends with the response key plus a newline. We encode this and then try to find it in the + # sequence of tokens. This should just be a single token. + RESPONSE_KEY = "### Response:" + RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n" + response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL) + + labels = batch["labels"].clone() + + for i in range(len(examples)): + + response_token_ids_start_idx = None + for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]: + response_token_ids_start_idx = idx + break + + if response_token_ids_start_idx is None: + raise RuntimeError( + f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}' + ) + + response_token_ids_end_idx = response_token_ids_start_idx + 1 + + # Make pytorch loss function ignore all tokens up through the end of the response key + labels[i, :response_token_ids_end_idx] = -100 + + batch["labels"] = labels + + return batch \ No newline at end of file diff --git a/pykoi/rlhf/supervised_finetuning.py b/pykoi/rlhf/supervised_finetuning.py index 7a58a9f..c5e8ed6 100644 --- a/pykoi/rlhf/supervised_finetuning.py +++ b/pykoi/rlhf/supervised_finetuning.py @@ -20,7 +20,7 @@ from pykoi.rlhf.config import RLHFConfig from pykoi.telemetry.events import SFTStartEvent, SFTStopEvent from pykoi.telemetry.telemetry import Telemetry - +from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM class SupervisedFinetuning: """ @@ -48,6 +48,13 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No self._telemetry = Telemetry(enable_telemetry) self._rlhf_config = rlhf_config self.tokenizer = AutoTokenizer.from_pretrained(rlhf_config.base_model_path) + # dh: add special tokens to tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + END_KEY = "### End" + INSTRUCTION_KEY = "### Instruction:" + RESPONSE_KEY = "### Response:" + RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n" + self.tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]}) self.num_proc = ( self._rlhf_config.num_workers if not self._rlhf_config.streaming else None ) @@ -83,13 +90,23 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No load_in_8bit=self._rlhf_config.load_in_8bit, device_map=self._rlhf_config.device_map, ) + # resize the token embeddings to include the added special tokens + self.model.resize_token_embeddings(len(self.tokenizer)) + + # dh: try the customized data collator that only predicts the answer part + data_collator = DataCollatorForCompletionOnlyLM( + tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8 + ) + self.trainer = SFTTrainer( model=self.model, args=self.training_args, train_dataset=self.dataset["train"], eval_dataset=self.dataset["eval"], - peft_config=self._rlhf_config.lora_config_rl, + peft_config=self._rlhf_config.lora_config_rl, ## TODO: DH: LoraConfig MAY BE IGNORED IF USING FROM_PRETRAINED packing=True, + data_collator=data_collator, + dataset_text_field="text", ) def train(self): @@ -103,6 +120,8 @@ def load_lora( base_model_path: Optional[str] = None, lora_model_path: Optional[str] = None, ): + #import pdb; pdb.set_trace() + # dh: not used if base_model_path is None: base_model_path = self._rlhf_config.base_model_path @@ -163,6 +182,65 @@ def prepare_sample_text(self, example): f" Answer: {example[self._rlhf_config.answer_title]}" ) return text + + + def prepare_d2l_text(self, example): + """Prepare the text from a sample of the d2l dataset .""" + INTRO_BLURB = ( + "Below is an instruction that describes a task. Write a response that appropriately completes the request." + ) + INSTRUCTION_KEY = "### Instruction:" + INPUT_KEY = "Input:" + RESPONSE_KEY = "### Response:" + END_KEY = "### End" + RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n" + DEFAULT_SEED = 42 + + # This is a training prompt that does not contain an input string. The instruction by itself has enough information + # to respond. For example, the instruction might ask for the year a historic figure was born. + PROMPT_NO_INPUT_FORMAT = """{intro} + {instruction_key} + {instruction} + {response_key} + {response} + {end_key}""".format( + intro=INTRO_BLURB, + instruction_key=INSTRUCTION_KEY, + instruction="{instruction}", + response_key=RESPONSE_KEY, + response="{response}", + end_key=END_KEY, + ) + + # This is a training prompt that contains an input string that serves as context for the instruction. For example, + # the input might be a passage from Wikipedia and the intruction is to extract some information from it. + PROMPT_WITH_INPUT_FORMAT = """{intro} + {instruction_key} + {instruction} + {input_key} + {input} + {response_key} + {response} + {end_key}""".format( + intro=INTRO_BLURB, + instruction_key=INSTRUCTION_KEY, + instruction="{instruction}", + input_key=INPUT_KEY, + input="{input}", + response_key=RESPONSE_KEY, + response="{response}", + end_key=END_KEY, + ) + + context = example.get("context") + if context: + text = PROMPT_WITH_INPUT_FORMAT.format(instruction=example["instruction"], response=example["response"], input=context) + else: + text = PROMPT_NO_INPUT_FORMAT.format(instruction=example["instruction"], response=example["instruction"]) + + + + return text def create_datasets(self, tokenizer, args): if args.dataset_type == "local_db": @@ -181,6 +259,7 @@ def create_datasets(self, tokenizer, args): elif args.dataset_type == "local_csv": dataset = load_dataset("csv", data_files=args.dataset_name) dataset = dataset[args.split] # Convert DatasetDict to Dataset + dataset2 = load_dataset("csv", data_files=args.dataset_name, split='train[:10%]') elif args.dataset_type == "huggingface": dataset = load_dataset( args.dataset_name, @@ -208,7 +287,7 @@ def create_datasets(self, tokenizer, args): train_dataset = ConstantLengthDataset( tokenizer, dataset["train"], - formatting_func=self.prepare_sample_text, + formatting_func=self.prepare_d2l_text, infinite=True, seq_length=args.max_seq_length, # chars_per_token=chars_per_token, @@ -216,7 +295,7 @@ def create_datasets(self, tokenizer, args): eval_dataset = ConstantLengthDataset( tokenizer, dataset["test"], - formatting_func=self.prepare_sample_text, + formatting_func=self.prepare_d2l_text, infinite=False, seq_length=args.max_seq_length, # chars_per_token=chars_per_token,