SFT for D2L. DPO #101

Closed
wants to merge 11 commits
4,698 changes: 4,698 additions & 0 deletions example/autorate/auto-rater.ipynb

Binary file added example/autorate/data/Chapter 5 Rome.docx
497 changes: 497 additions & 0 deletions example/autorate/data/rome.txt

2,647 changes: 1,244 additions & 1,403 deletions example/data_generation/immigration_gen_data.ipynb

584 changes: 584 additions & 0 deletions example/rlhf/demo_supervised_finetuning_d2l_eval.ipynb

55 changes: 55 additions & 0 deletions example/rlhf/supervised_finetuning_demo_d2l.py
@@ -0,0 +1,55 @@
"""Demo for the supervised fine tuning.

python -m example.rlhf.supervised_finetuning_demo_d2l
"""

from peft import LoraConfig
from pykoi.chat import QuestionAnswerDatabase
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
QA_CSV_HEADER_QUESTION,
QA_CSV_HEADER_VOTE_STATUS)
from pykoi.rlhf import RLHFConfig, SupervisedFinetuning

# get data from local database
qa_database = QuestionAnswerDatabase()
my_data_pd = qa_database.retrieve_all_question_answers_as_pandas()
my_data_pd = my_data_pd[
[
QA_CSV_HEADER_ID,
QA_CSV_HEADER_QUESTION,
QA_CSV_HEADER_ANSWER,
QA_CSV_HEADER_VOTE_STATUS,
]
]

# analyze the data
print(my_data_pd)
print("My local database has {} samples in total".format(my_data_pd.shape[0]))

# run supervised finetuning
config = RLHFConfig(base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
dataset_type="local_csv", dataset_name="data/chapter22_trnvalfromseed_data_processed.csv",
train_test_split_ratio=0,  # ratio for the test set. DH: TODO: combine train and eval
max_seq_length=896,
per_device_eval_batch_size=1,
log_freq=20,
# DH: NOTE: one epoch iterates the dataset once, so log_freq=20 logs after every 20 entries when the training batch size is 1
# (i.e., log_freq corresponds to roughly 0.12 epoch when the dataset has 166 entries).
save_freq=40000,
num_train_epochs=20,
max_steps=-1, # if a positive number is given, it will override num_train_epochs
device_map="auto",
lora_config_rl=LoraConfig(
r=512,
lora_alpha=1024,
lora_dropout=0.05,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", ], # "gate_proj","up_proj","down_proj",], #"lm_head",],
bias="none",
task_type="CAUSAL_LM"
),
data_collator="DataCollatorForCompletionOnlyLM",
no_evaluation=True,
prepare_text="d2l",
)
rlhf_step1_sft = SupervisedFinetuning(config)
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
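
Once training completes, the saved adapter can be smoke-tested. A minimal inference sketch (illustrative only, assuming train_and_save writes a standard PEFT adapter directory; the prompt template is an assumption loosely mirroring the "### Response:" key used by the collator below):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and attach the LoRA adapter saved above (path mirrors the demo).
base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "./models/rlhf_step1_sft")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

prompt = "### Question:\nWho founded Rome?\n\n### Response:\n"  # hypothetical prompt format
inputs = tokenizer(prompt, return_tensors="pt").to(base.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))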
14 changes: 14 additions & 0 deletions pykoi/rlhf/config.py
@@ -5,6 +5,7 @@

from accelerate import Accelerator
from peft import LoraConfig, TaskType
# TODO (DH): revisit the num_train_epochs default (the D2L demo uses 20).


@dataclass
@@ -119,6 +120,7 @@ class RLHFConfig:
default="./rlhf_checkpoints",
metadata={"help": "Output directory for all model weights."},
)
num_train_epochs: Optional[int] = field(
    default=5, metadata={"help": "Number of supervised fine-tuning training epochs."}
)
log_freq: Optional[int] = field(default=1, metadata={"help": "Logging frequency."})
eval_freq: Optional[int] = field(
default=1000, metadata={"help": "Evaluation frequency."}
@@ -182,6 +184,18 @@ class RLHFConfig:
),
metadata={"help": "LoRA configuration."},
)
data_collator: Optional[str] = field(
default=None,
metadata={"help": "The name of data collator to use for training."},
)
no_evaluation: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable evaluations during training."},
)
prepare_text: Optional[str] = field(
default="sample",
metadata={"help": "How to prepare the text for the model."},
)

# Step 2 reward modeling parameters
reward_model_path: Optional[str] = field(
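
For context, a hypothetical sketch of how a trainer might resolve the new data_collator string into a collator instance; this only illustrates the dispatch pattern and is not the actual pykoi implementation:

from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM

def resolve_data_collator(rlhf_config, tokenizer):
    # Map the config string to a collator object; None falls back to the trainer default.
    if rlhf_config.data_collator == "DataCollatorForCompletionOnlyLM":
        # Mask prompt tokens so the LM loss is computed on the completion only.
        return DataCollatorForCompletionOnlyLM(tokenizer=tokenizer, mlm=False)
    return None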
40 changes: 40 additions & 0 deletions pykoi/rlhf/customize_data_collator.py
@@ -0,0 +1,40 @@
from typing import Any, Dict, List, Union
from transformers import DataCollatorForLanguageModeling
import numpy as np


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
def torch_call(
self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)

# The prompt ends with the response key plus a newline. We encode the response key and
# then search the labels for its first token (assumed to be a single distinctive token).
RESPONSE_KEY = "### Response:"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

labels = batch["labels"].clone()

for i in range(len(examples)):

response_token_ids_start_idx = None
for idx in np.where(
batch["labels"][i] == response_token_ids[0])[0]:
response_token_ids_start_idx = idx
break

if response_token_ids_start_idx is None:
raise RuntimeError(
f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
)

response_token_ids_end_idx = response_token_ids_start_idx + 1

# Make pytorch loss function ignore all tokens up through the end
# of the response key
labels[i, :response_token_ids_end_idx] = -100

batch["labels"] = labels

return batch
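
A small usage sketch for the collator above (illustrative; the tokenizer choice and sample text are arbitrary assumptions):

from transformers import AutoTokenizer
from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer, an assumption
tokenizer.pad_token = tokenizer.eos_token
collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer, mlm=False)

sample = tokenizer("Question: Who founded Rome?\n\n### Response:\nLegend credits Romulus.")
batch = collator([sample])
# Labels up to and including the first token of "### Response:\n" are set to -100,
# so the language-modeling loss covers only the completion.
print(batch["labels"][0])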
263 changes: 263 additions & 0 deletions pykoi/rlhf/rl_finetuning_dpo.py
@@ -0,0 +1,263 @@
# The code is adapted from Huggingface.
Collaborator review comment: rename to dpo.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
# TODO: bump transformers version in requirements at next release.

# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from accelerate import PartialState
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import DPOTrainer, is_xpu_available


# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the DPO training script.
"""

# data parameters
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
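# For reference (explanatory note, not part of the original script): beta controls how
# strongly the policy is penalized for deviating from the reference model in the DPO loss
#   L = -log sigmoid( beta * [ log pi_theta(y_w|x)/pi_ref(y_w|x)
#                             - log pi_theta(y_l|x)/pi_ref(y_l|x) ] )
# where y_w / y_l are the chosen / rejected completions; a larger beta keeps the policy
# closer to the reference model.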

# training parameters
model_name_or_path: Optional[str] = field(default="models/rlhf_step1_sft/", metadata={"help": "the model name"})
learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "batch size per device"})  # DH: reduced from the upstream default of 4
gradient_accumulation_steps: Optional[int] = field(
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
output_dir: Optional[str] = field(default="outputdpo", metadata={"help": "the output directory"})
fp16: Optional[bool] = field(
default=False, metadata={"help": "Whether to activate fp16 mixed precision during training"}
)
bf16: Optional[bool] = field(
default=False, metadata={"help": "Whether to activate bf16 mixed precision during training"}
)
max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: Optional[int] = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
max_steps: Optional[int] = field(default=500, metadata={"help": "max number of training steps"})  # reduced from the upstream default of 1000

# lora parameters
use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT or not to train adapters"})
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
# instrumentation
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
default=None,
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
},
)
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
)
gradient_checkpointing_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
generate_during_eval: Optional[bool] = field(default=False, metadata={"help": "Generate during evaluation"})


def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]

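# Example (illustrative): for a multi-turn pair such as
#   "\n\nHuman: Hi\n\nAssistant: Hello\n\nHuman: What is DPO?\n\nAssistant: It is ..."
# rfind locates the LAST "\n\nAssistant:" marker, so the returned prompt is everything
# up to and including that final marker; the text after it is the completion.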

def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}

Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))

def split_prompt_and_responses(sample) -> Dict[str, str]:
prompt = extract_anthropic_prompt(sample["chosen"])
return {
"prompt": prompt,
"chosen": sample["chosen"][len(prompt) :],
"rejected": sample["rejected"][len(prompt) :],
}

return dataset.map(split_prompt_and_responses)

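# Illustrative result of the mapping above: each hh-rlhf record becomes a triple like
#   {"prompt":   "\n\nHuman: ...\n\nAssistant:",
#    "chosen":   " preferred completion",
#    "rejected": " dispreferred completion"}
# which is the (prompt, chosen, rejected) format DPOTrainer consumes.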

if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

if script_args.load_in_8bit and script_args.load_in_4bit:
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
# Copy the model to each device
device_map = (
{"": f"xpu:{PartialState().local_process_index}"}
if is_xpu_available()
else {"": PartialState().local_process_index}
)
torch_dtype = torch.bfloat16
else:
# DH: use "auto" instead of the upstream default of None
device_map = "auto"
quantization_config = None
torch_dtype = None

# 1. load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
device_map=device_map,
quantization_config=quantization_config,
torch_dtype=torch_dtype,
)

if script_args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]

if not script_args.use_peft:
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
else:
# If one uses PEFT, there is no need to load a separate reference model  # DH: TODO: check this
model_ref = None

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# 2. Load the Anthropic Helpful-Harmless dataset
# train_dataset = get_hh("train", sanity_check=script_args.sanity_check)

# 3. Load evaluation dataset
# eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)

# DH: train on a local preference CSV instead of the HH dataset above
dataset = load_dataset("csv", data_files="data/rlhf_training_data_d2ai.csv", split="train")
def feature_format(sample) -> Dict[str, str]:
return {
"prompt": sample["input"],
"chosen": sample["chosen"],
"rejected": sample["rejected"],
}
dataset = dataset.map(feature_format)
train_eval = dataset.train_test_split(test_size=0.1)
train_dataset = train_eval["train"]
eval_dataset = train_eval["test"]
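# The local CSV is expected to provide at least the columns "input", "chosen" and "rejected";
# a hypothetical row: input="Explain attention in one sentence.",
# chosen="Attention weights token interactions by learned relevance scores.",
# rejected="Attention is a 2017 movie." feature_format above simply renames "input" to "prompt".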


# 4. initialize training arguments:
training_args = TrainingArguments(
per_device_train_batch_size=script_args.per_device_train_batch_size,
max_steps=script_args.max_steps,
remove_unused_columns=False,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
learning_rate=script_args.learning_rate,
evaluation_strategy="steps",
logging_first_step=True,
logging_steps=10, # match results in blog post
eval_steps=500,
output_dir=script_args.output_dir,
optim="rmsprop",
warmup_steps=150,
report_to=script_args.report_to,
bf16=script_args.bf16,
fp16=script_args.fp16,
gradient_checkpointing=script_args.gradient_checkpointing,
# TODO: uncomment that on the next transformers release
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
)

if script_args.use_peft:
peft_config = LoraConfig(
r=script_args.peft_lora_r,
lora_alpha=script_args.peft_lora_alpha,
bias="none",
task_type="CAUSAL_LM",
)
else:
peft_config = None

# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=script_args.max_length,
max_target_length=script_args.max_target_length,
max_prompt_length=script_args.max_prompt_length,
generate_during_eval=script_args.generate_during_eval,
peft_config=peft_config,
)

# 6. train
dpo_trainer.train()
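
As with the SFT demo, every ScriptArguments field is exposed as a command-line flag through HfArgumentParser, so the script can be driven without code changes. An illustrative invocation (the values simply mirror the defaults set above):

python -m pykoi.rlhf.rl_finetuning_dpo \
    --model_name_or_path models/rlhf_step1_sft/ \
    --per_device_train_batch_size 1 \
    --max_steps 500 \
    --output_dir outputdpo

The script relies on the Trainer's checkpointing to persist weights; calling dpo_trainer.save_model(script_args.output_dir) after train() would be one way to save the final policy explicitly.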