SFT for D2L. DPO #101
Closed
Commits (11, all by llauraa23):
eb4878f  auto rater, sample data and prompt engineering
5c9cb03  Merge branch 'CambioML:main' into main
6e9f880  merge conflict
1ab4f32  support supervised fine tuning on d2l.
9daf69c  resolve merge conflicts on gpu96
878f44e  Merge branch 'main' of https://github.com/llauraa23/pykoi
1428aba  support training multiple epochs in sft.
880fefc  implement evaluation of fine-tuned models with pykoi pipeline
58f946c  When evaluating the SFT model, store the questions and generated answ…
beeedaa  code cleanup for d2l demo. In SFT, make data collator, formatting fun…
04b9fa5  DPO training on d2l data. Version 0
example/data_generation/immigration_gen_data.ipynb: 2,647 changes (1,244 additions, 1,403 deletions); large diff not rendered.
Other changed files (large diffs and one binary file) are not rendered.
example/rlhf/supervised_finetuning_demo_d2l.py (new file, +55 lines):
"""Demo for the supervised fine tuning. | ||
|
||
python -m example.rlhf.supervised_finetuning_demo_d2l | ||
""" | ||
|
||
from peft import LoraConfig | ||
from pykoi.chat import QuestionAnswerDatabase | ||
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID, | ||
QA_CSV_HEADER_QUESTION, | ||
QA_CSV_HEADER_VOTE_STATUS) | ||
from pykoi.rlhf import RLHFConfig, SupervisedFinetuning | ||
|
||
# get data from local database | ||
qa_database = QuestionAnswerDatabase() | ||
my_data_pd = qa_database.retrieve_all_question_answers_as_pandas() | ||
my_data_pd = my_data_pd[ | ||
[ | ||
QA_CSV_HEADER_ID, | ||
QA_CSV_HEADER_QUESTION, | ||
QA_CSV_HEADER_ANSWER, | ||
QA_CSV_HEADER_VOTE_STATUS, | ||
] | ||
] | ||
|
||
# analyze the data | ||
print(my_data_pd) | ||
print("My local database has {} samples in total".format(my_data_pd.shape[0])) | ||
|
||
# run supervised finetuning | ||
config = RLHFConfig(base_model_path="mistralai/Mistral-7B-Instruct-v0.1", | ||
dataset_type="local_csv", dataset_name="data/chapter22_trnvalfromseed_data_processed.csv", | ||
train_test_split_ratio=0, # ratio for test set DH:TODO: COBINE TRAIN AND EVAL | ||
max_seq_length=896, | ||
per_device_eval_batch_size=1, | ||
log_freq=20, | ||
# dh: NOTE: 1 EPOCH iterates the dataset once. So log freq 20 means iterating 20 entries when training batch size = 1. | ||
# (i.e., log_freq = 0.12 epoch when the dataset has 166 entires). | ||
save_freq=40000, | ||
num_train_epochs=20, | ||
max_steps=-1, # if a positive number is given, it will override num_train_epochs | ||
device_map="auto", | ||
lora_config_rl=LoraConfig( | ||
r=512, | ||
lora_alpha=1024, | ||
lora_dropout=0.05, | ||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", ], # "gate_proj","up_proj","down_proj",], #"lm_head",], | ||
bias="none", | ||
task_type="CAUSAL_LM" | ||
), | ||
data_collator="DataCollatorForCompletionOnlyLM", | ||
no_evaluation=True, | ||
prepare_text="d2l", | ||
) | ||
rlhf_step1_sft = SupervisedFinetuning(config) | ||
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft") |
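A note on the adapter size: r=512 with lora_alpha=1024 on four attention projections is an unusually large LoRA. A quick way to see what that implies is to count trainable parameters; below is a minimal sketch, assuming the same base model and the standard peft API (loading Mistral-7B requires substantial memory):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
lora = LoraConfig(
    r=512,
    lora_alpha=1024,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
# Wrap the base model and report trainable vs. total parameter counts.
get_peft_model(base, lora).print_trainable_parameters()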
New file (+40 lines): the DataCollatorForCompletionOnlyLM referenced by the SFT config above.
from typing import Any, Dict, List, Union

import numpy as np
from transformers import DataCollatorForLanguageModeling


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    """Collator that masks prompt tokens so the loss covers only the completion."""

    def torch_call(
        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
    ) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline. Encode it and
        # search for it in the sequence of tokens; only the first token of the
        # encoded key is actually matched below.
        RESPONSE_KEY = "### Response:"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

        labels = batch["labels"].clone()

        for i in range(len(examples)):
            # Locate the first occurrence of the response key's first token.
            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break

            if response_token_ids_start_idx is None:
                raise RuntimeError(
                    f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
                )

            response_token_ids_end_idx = response_token_ids_start_idx + 1

            # Make the PyTorch loss function ignore all tokens up through the
            # end of the response key.
            labels[i, :response_token_ids_end_idx] = -100

        batch["labels"] = labels

        return batch
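A minimal usage sketch for this collator, assuming the class above is in scope (the example text and the choice of GPT-2 are illustrative; any causal-LM tokenizer with a pad token works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer, mlm=False)
example = tokenizer("What does DPO stand for?\n### Response:\nDirect Preference Optimization.")
batch = collator([example])
# Labels up through the first token of "### Response:\n" are -100, so the
# loss is computed only on the completion.
print(batch["labels"])

Because only the first token of the encoded response key is matched, the prompt text must not contain that token before the key, or the mask boundary lands too early.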
New file (+263 lines): a DPO training script adapted from the Hugging Face TRL example.
# The code is adapted from Hugging Face.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: you need to install transformers from main to run this script. See
# https://huggingface.co/docs/transformers/installation#install-from-source
# TODO: bump transformers version in requirements at next release.

# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from accelerate import PartialState
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
)
from trl import DPOTrainer, is_xpu_available


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """The arguments for the DPO training script."""
    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="models/rlhf_step1_sft/", metadata={"help": "the model name"}
    )
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    output_dir: Optional[str] = field(default="outputdpo", metadata={"help": "the output directory"})
    fp16: Optional[bool] = field(
        default=False, metadata={"help": "whether to activate fp16 mixed precision during training"}
    )
    bf16: Optional[bool] = field(
        default=False, metadata={"help": "whether to activate bf16 mixed precision during training"}
    )
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(
        default=128, metadata={"help": "maximum target length; only used for encoder-decoder models"}
    )
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non-response tokens"})
    max_steps: Optional[int] = field(default=500, metadata={"help": "max number of training steps"})

    # LoRA parameters
    use_peft: Optional[bool] = field(default=True, metadata={"help": "whether to use PEFT to train adapters"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})

    # instrumentation
    sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are '
            '`"azure_ml"`, `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "whether to use gradient checkpointing or not"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "keyword arguments to pass along to the `torch.utils.checkpoint.checkpoint` method, "
            "e.g. `use_reentrant=False`"
        },
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
    generate_during_eval: Optional[bool] = field(default=False, metadata={"help": "generate during evaluation"})

def extract_anthropic_prompt(prompt_and_response):
    """Extract the Anthropic prompt from a prompt-and-response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]
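# Example (illustrative strings, not taken from the dataset): given
#   "\n\nHuman: What is DPO?\n\nAssistant: A preference-tuning method."
# the function returns everything up through the last "\n\nAssistant:",
# i.e. "\n\nHuman: What is DPO?\n\nAssistant:".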

def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt):],
            "rejected": sample["rejected"][len(prompt):],
        }

    return dataset.map(split_prompt_and_responses)
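# Each mapped record pairs one shared prompt with both completions, e.g.:
#   {"prompt": "\n\nHuman: ...\n\nAssistant:",
#    "chosen": " <preferred reply>", "rejected": " <other reply>"}
# (Schematic values; the real dataset contains full multi-turn dialogues.)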

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    if script_args.load_in_8bit and script_args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    elif script_args.load_in_8bit or script_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
        )
        # Copy the model to each device.
        device_map = (
            {"": f"xpu:{PartialState().local_process_index}"}
            if is_xpu_available()
            else {"": PartialState().local_process_index}
        )
        torch_dtype = torch.bfloat16
    else:
        device_map = "auto"  # changed from None so accelerate places the weights automatically
        quantization_config = None
        torch_dtype = None

    # 1. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        device_map=device_map,
        quantization_config=quantization_config,
        torch_dtype=torch_dtype,
    )

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    if not script_args.use_peft:
        model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
    else:
        # If one uses PEFT, there is no need to load a reference model. TODO(DH): check this.
        model_ref = None

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the Anthropic Helpful-Harmless dataset
    # train_dataset = get_hh("train", sanity_check=script_args.sanity_check)

    # 3. Load evaluation dataset
    # eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)

    # Use the local D2L preference data instead; the CSV is expected to
    # provide "input", "chosen" and "rejected" columns.
    dataset = load_dataset("csv", data_files="data/rlhf_training_data_d2ai.csv", split="train")

    def feature_format(sample) -> Dict[str, str]:
        return {
            "prompt": sample["input"],
            "chosen": sample["chosen"],
            "rejected": sample["rejected"],
        }

    dataset = dataset.map(feature_format)
    train_eval = dataset.train_test_split(test_size=0.1)
    train_dataset = train_eval["train"]
    eval_dataset = train_eval["test"]
    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        max_steps=script_args.max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=500,
        output_dir=script_args.output_dir,
        optim="rmsprop",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=script_args.bf16,
        fp16=script_args.fp16,
        gradient_checkpointing=script_args.gradient_checkpointing,
        # TODO: uncomment on the next transformers release
        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
    )

    if script_args.use_peft:
        peft_config = LoraConfig(
            r=script_args.peft_lora_r,
            lora_alpha=script_args.peft_lora_alpha,
            bias="none",
            task_type="CAUSAL_LM",
        )
    else:
        peft_config = None

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=script_args.max_length,
        max_target_length=script_args.max_target_length,
        max_prompt_length=script_args.max_prompt_length,
        generate_during_eval=script_args.generate_during_eval,
        peft_config=peft_config,
    )

    # 6. train
    dpo_trainer.train()
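Since the script parses ScriptArguments with HfArgumentParser, each dataclass field above doubles as a command-line flag, so a run could be launched as, for example, python <script>.py --learning_rate 5e-4 --max_steps 500 --sanity_check True. The exact script path is not shown in this diff (the review below suggests naming it dpo.py), so treat that invocation as a sketch rather than a documented CLI.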
Review comment: rename to dpo.py