Skip to content

Commit

Permalink
support supervised fine tuning on d2l.
Browse files Browse the repository at this point in the history
execute with "python -m example.rlhf.supervised_finetuning_d2l"
  • Loading branch information
llauraa23 committed Jan 9, 2024
1 parent 5c9cb03 commit 1ab4f32
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 4 deletions.
45 changes: 45 additions & 0 deletions example/rlhf/supervised_finetuning_d2l.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Demo for the supervised fine tuning.
python -m example.rlhf.supervised_finetuning_demo
"""

from pykoi.chat import QuestionAnswerDatabase
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
QA_CSV_HEADER_QUESTION,
QA_CSV_HEADER_VOTE_STATUS)
from pykoi.rlhf import RLHFConfig, SupervisedFinetuning

# get data from local database
qa_database = QuestionAnswerDatabase()
my_data_pd = qa_database.retrieve_all_question_answers_as_pandas()
my_data_pd = my_data_pd[
[
QA_CSV_HEADER_ID,
QA_CSV_HEADER_QUESTION,
QA_CSV_HEADER_ANSWER,
QA_CSV_HEADER_VOTE_STATUS,
]
]

# analyze the data
print(my_data_pd)
print("My local database has {} samples in total".format(my_data_pd.shape[0]))

# run supervised finetuning
from peft import LoraConfig
config = RLHFConfig(base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
dataset_type="local_csv", dataset_name="data/chapter22_trnvalfromseed_data_processed.csv",
train_test_split_ratio=0.1,
max_seq_length=896,
per_device_eval_batch_size = 1,
lora_config_rl = LoraConfig(
r=512,
lora_alpha=1024,
lora_dropout=0.05,
target_modules=["q_proj","k_proj","v_proj","o_proj",], # "gate_proj","up_proj","down_proj",], #"lm_head",],
bias="none",
task_type="CAUSAL_LM"
),
)
rlhf_step1_sft = SupervisedFinetuning(config)
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
1 change: 1 addition & 0 deletions pykoi/rlhf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from accelerate import Accelerator
from peft import LoraConfig, TaskType
# TODO: DH: num_train_epochs=20,


@dataclass
Expand Down
35 changes: 35 additions & 0 deletions pykoi/rlhf/customize_data_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Dict, List, Tuple, Union
from transformers import DataCollatorForLanguageModeling
import numpy as np
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)

# The prompt ends with the response key plus a newline. We encode this and then try to find it in the
# sequence of tokens. This should just be a single token.
RESPONSE_KEY = "### Response:"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

labels = batch["labels"].clone()

for i in range(len(examples)):

response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
response_token_ids_start_idx = idx
break

if response_token_ids_start_idx is None:
raise RuntimeError(
f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
)

response_token_ids_end_idx = response_token_ids_start_idx + 1

# Make pytorch loss function ignore all tokens up through the end of the response key
labels[i, :response_token_ids_end_idx] = -100

batch["labels"] = labels

return batch
87 changes: 83 additions & 4 deletions pykoi/rlhf/supervised_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pykoi.rlhf.config import RLHFConfig
from pykoi.telemetry.events import SFTStartEvent, SFTStopEvent
from pykoi.telemetry.telemetry import Telemetry

from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM

class SupervisedFinetuning:
"""
Expand Down Expand Up @@ -48,6 +48,13 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
self._telemetry = Telemetry(enable_telemetry)
self._rlhf_config = rlhf_config
self.tokenizer = AutoTokenizer.from_pretrained(rlhf_config.base_model_path)
# dh: add special tokens to tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
END_KEY = "### End"
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
self.tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
self.num_proc = (
self._rlhf_config.num_workers if not self._rlhf_config.streaming else None
)
Expand Down Expand Up @@ -83,13 +90,23 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
load_in_8bit=self._rlhf_config.load_in_8bit,
device_map=self._rlhf_config.device_map,
)
# resize the token embeddings to include the added special tokens
self.model.resize_token_embeddings(len(self.tokenizer))

# dh: try the customized data collator that only predicts the answer part
data_collator = DataCollatorForCompletionOnlyLM(
tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)

self.trainer = SFTTrainer(
model=self.model,
args=self.training_args,
train_dataset=self.dataset["train"],
eval_dataset=self.dataset["eval"],
peft_config=self._rlhf_config.lora_config_rl,
peft_config=self._rlhf_config.lora_config_rl, ## TODO: DH: LoraConfig MAY BE IGNORED IF USING FROM_PRETRAINED
packing=True,
data_collator=data_collator,
dataset_text_field="text",
)

def train(self):
Expand All @@ -103,6 +120,8 @@ def load_lora(
base_model_path: Optional[str] = None,
lora_model_path: Optional[str] = None,
):
#import pdb; pdb.set_trace()
# dh: not used
if base_model_path is None:
base_model_path = self._rlhf_config.base_model_path

Expand Down Expand Up @@ -163,6 +182,65 @@ def prepare_sample_text(self, example):
f" Answer: {example[self._rlhf_config.answer_title]}"
)
return text


def prepare_d2l_text(self, example):
"""Prepare the text from a sample of the d2l dataset ."""
INTRO_BLURB = (
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
)
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
DEFAULT_SEED = 42

# This is a training prompt that does not contain an input string. The instruction by itself has enough information
# to respond. For example, the instruction might ask for the year a historic figure was born.
PROMPT_NO_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
{response}
{end_key}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
response_key=RESPONSE_KEY,
response="{response}",
end_key=END_KEY,
)

# This is a training prompt that contains an input string that serves as context for the instruction. For example,
# the input might be a passage from Wikipedia and the intruction is to extract some information from it.
PROMPT_WITH_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{input_key}
{input}
{response_key}
{response}
{end_key}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
input_key=INPUT_KEY,
input="{input}",
response_key=RESPONSE_KEY,
response="{response}",
end_key=END_KEY,
)

context = example.get("context")
if context:
text = PROMPT_WITH_INPUT_FORMAT.format(instruction=example["instruction"], response=example["response"], input=context)
else:
text = PROMPT_NO_INPUT_FORMAT.format(instruction=example["instruction"], response=example["instruction"])



return text

def create_datasets(self, tokenizer, args):
if args.dataset_type == "local_db":
Expand All @@ -181,6 +259,7 @@ def create_datasets(self, tokenizer, args):
elif args.dataset_type == "local_csv":
dataset = load_dataset("csv", data_files=args.dataset_name)
dataset = dataset[args.split] # Convert DatasetDict to Dataset
dataset2 = load_dataset("csv", data_files=args.dataset_name, split='train[:10%]')
elif args.dataset_type == "huggingface":
dataset = load_dataset(
args.dataset_name,
Expand Down Expand Up @@ -208,15 +287,15 @@ def create_datasets(self, tokenizer, args):
train_dataset = ConstantLengthDataset(
tokenizer,
dataset["train"],
formatting_func=self.prepare_sample_text,
formatting_func=self.prepare_d2l_text,
infinite=True,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)
eval_dataset = ConstantLengthDataset(
tokenizer,
dataset["test"],
formatting_func=self.prepare_sample_text,
formatting_func=self.prepare_d2l_text,
infinite=False,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
Expand Down

0 comments on commit 1ab4f32

Please sign in to comment.