Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scripts/run_kto.py #165

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 257 additions & 0 deletions scripts/run_kto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
import sys

import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed

from alignment import (
DataArguments,
H4ArgumentParser,
ModelArguments,
apply_chat_template,
decontaminate_humaneval,
get_checkpoint,
get_datasets,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
is_adapter_model,
)
from peft import PeftConfig, PeftModel
from trl import KTOConfig, KTOTrainer, setup_chat_format


logger = logging.getLogger(__name__)


def main():
parser = H4ArgumentParser((ModelArguments, DataArguments, KTOConfig))
model_args, data_args, training_args = parser.parse()

#######
# Setup
#######
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Data parameters {data_args}")
logger.info(f"Training/evaluation parameters {training_args}")

# Check for last checkpoint
last_checkpoint = get_checkpoint(training_args)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

# Set seed for reproducibility
set_seed(training_args.seed)

###############
# Load datasets
###############
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["prompt", "completion", "label"],
)
logger.info(
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
column_names = list(raw_datasets["train"].features)

################
# Load tokenizer
################
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
tokenizer = get_tokenizer(model_args, data_args)

#####################################################
# Load model (required here to setup the chat format)
#####################################################
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)

model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)

if is_adapter_model(model_args.model_name_or_path, model_args.model_revision):
logger.info(f"Loading adapter for {model_args.model_name_or_path=}")

peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
base_model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
**model_kwargs,
)
model = PeftModel.from_pretrained(
base_model,
model_args.model_name_or_path,
revision=model_args.model_revision,
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)

# For ChatML we need to add special tokens and resize the embedding layer
if "<|im_start|>" in tokenizer.chat_template:
model, tokenizer = setup_chat_format(model, tokenizer)

#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={
"tokenizer": tokenizer,
"task": "kto",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Formatting prompt-completion pairs with prompt template",
)

##########################
# Decontaminate benchmarks
##########################
num_raw_train_samples = len(raw_datasets["train"])
raw_datasets = raw_datasets.filter(
decontaminate_humaneval,
fn_kwargs={"text_column": "completion"},
batched=True,
batch_size=10_000,
num_proc=1,
desc="Decontaminating HumanEval samples",
)
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
logger.info(
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
)

# Log a few random samples from the training set:
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
logger.info(
f"Completion sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['completion']}"
)

ref_model = None
if not model_args.use_peft:
logger.info(f"Loading reference model for {model_args.model_name_or_path=}")
ref_model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)

########################
# Instantiate KTOTrainer
########################
trainer = KTOTrainer(
model,
ref_model,
args=training_args,
train_dataset=raw_datasets["train"],
eval_dataset=raw_datasets["test"] if "test" in raw_datasets else None,
tokenizer=tokenizer,
peft_config=get_peft_config(model_args) if model_args.use_peft else None,
)

###############
# Training loop
###############
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

logger.info("*** Training complete ***")

##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")

# Save everything else on main process
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": list(data_args.dataset_mixer.keys()),
"dataset_tags": list(data_args.dataset_mixer.keys()),
"tags": ["alignment-handbook"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)

##########
# Evaluate
##########
if training_args.do_eval and "test" in raw_datasets:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(raw_datasets["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

if training_args.push_to_hub is True:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)

logger.info("*** Training complete! ***")


if __name__ == "__main__":
main()
17 changes: 15 additions & 2 deletions src/alignment/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def maybe_insert_system_message(messages, tokenizer):
def apply_chat_template(
example,
tokenizer,
task: Literal["sft", "generation", "rm", "dpo"],
task: Literal["sft", "generation", "rm", "dpo", "kto"],
auto_insert_empty_system_msg: bool = True,
):
if task in ["sft", "generation"]:
Expand Down Expand Up @@ -101,9 +101,22 @@ def apply_chat_template(
f"Could not format example as dialogue for `{task}` task! Require either the "
f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
)
elif task == "kto":
if all(k in example.keys() for k in ("prompt", "completion", "label")):
if not is_openai_format(example["prompt"]) or not is_openai_format(example["completion"]):
raise ValueError(
f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
)
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
else:
raise ValueError(
f"Could not format example as dialogue for `{task}` task! Requires the keys `[prompt, completion, label]`"
f" but found {list(example.keys())} instead."
)
else:
raise ValueError(
f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo', 'kto']"
)
return example

Expand Down
1 change: 1 addition & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from copy import deepcopy

Expand Down
Loading