Commit

this won't work
ProbablyFaiz committed Sep 20, 2024
1 parent 7cee793 commit fe80f45
Showing 1 changed file with 99 additions and 96 deletions.
195 changes: 99 additions & 96 deletions rl/llm/train_llm.py
@@ -1,5 +1,5 @@
# type: ignore
"""A script to train an LLM with LoRA."""
"""A script to train an LLM with LoRA and Unsloth for multi-GPU training."""

import hashlib
from pathlib import Path
@@ -8,8 +8,15 @@
import pandas as pd
import peft
import torch
import torch.distributed as DISTRIBUTED
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
)

import rl.llm.config
import rl.llm.merge_lora
@@ -26,14 +33,12 @@
"lora_alpha": 16,
"lora_dropout": 0.0,
"bias": "none",
# "task_type": "CAUSAL_LM",
}

_DPO_BETA = 0.1
_VALIDATION_SPLIT = 0.1
_MAX_VALIDATION_SIZE = 500


_DEFAULT_BASE_OUTPUT_DIR = rl.utils.io.get_model_path("lora")
_DEFAULT_MERGED_DIR = rl.utils.io.get_model_path("merged")

@@ -193,10 +198,21 @@ def main(
abort=True,
)

if unsloth:
DISTRIBUTED.init_process_group("nccl")
local_rank = DISTRIBUTED.get_rank()
rl.utils.LOGGER.info(f"Unsloth: Loading model on GPU device id = {local_rank}.")
else:
local_rank = 0

full_dataset = get_dataset(train_data_path, val_data_path)

world_size = int(rl.utils.io.getenv("WORLD_SIZE", 1))
model, tokenizer = get_model(base_model_id, model_architecture, world_size, unsloth)
if unsloth:
model, tokenizer = get_unsloth_model(
base_model_id, model_architecture, local_rank
)
else:
model, tokenizer = get_model(base_model_id, model_architecture)

trainer = get_trainer(
model=model,
@@ -205,26 +221,28 @@
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
eval_steps=eval_steps,
world_size=world_size,
output_dir=output_dir,
tokenizer=tokenizer,
full_dataset=full_dataset,
dpo=dpo,
dpo_beta=dpo_beta,
deepspeed_config=deepspeed_config,
unsloth=unsloth,
)

if trainer.accelerator.is_main_process:
if trainer.is_world_process_zero():
log_initial_run_info(base_model_id, name, train_data_path, val_data_path)

trainer.train(resume_from_checkpoint=True if output_exists else None)
trainer.accelerator.wait_for_everyone()
trainer.save_state()

if trainer.accelerator.is_main_process:
if trainer.is_world_process_zero():
save_model(trainer, tokenizer, output_dir)

del trainer, tokenizer, model
if merge_after:
if unsloth:
DISTRIBUTED.destroy_process_group()

if merge_after and trainer.is_world_process_zero():
merged_dir = _DEFAULT_MERGED_DIR / (name or output_dir.name)
rl.llm.merge_lora.main.callback(
base_model_id=base_model_id,
@@ -239,7 +257,6 @@ def _get_default_output_dir(name: str) -> Path:

def get_dataset(train_data_path: Path, val_data_path: Path) -> datasets.Dataset:
df = pd.read_json(train_data_path, lines=True)
# Removing the metadata because it causes weird problems when loading the dataset.
df = df.drop(columns=["metadata"])
dataset = datasets.Dataset.from_pandas(df)
if val_data_path:
@@ -254,15 +271,58 @@ def get_dataset(train_data_path: Path, val_data_path: Path) -> datasets.Dataset:
return dataset


def get_tokenizer(base_model_id):
def get_tokenizer(base_model_id: str) -> AutoTokenizer:
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
tokenizer.pad_token_id = 0 # unk
tokenizer.padding_side = "left"
return tokenizer


def get_unsloth_model(
base_model_id: str, model_architecture: str, local_rank: int
) -> tuple:
from unsloth import FastLanguageModel

max_seq_length = 4096
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=base_model_id,
max_seq_length=max_seq_length,
dtype=None,
load_in_4bit=_BNB_CONFIG.load_in_4bit if _BNB_CONFIG else False,
device_map={"": local_rank},
)

if model_architecture in ("llama", "mistral"):
target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
elif model_architecture == "phi":
target_modules = ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"]
else:
target_modules = LORA_CONFIG.get("target_modules", [])

model = FastLanguageModel.get_peft_model(
model,
r=LORA_CONFIG["r"],
lora_alpha=LORA_CONFIG["lora_alpha"],
target_modules=target_modules,
lora_dropout=LORA_CONFIG["lora_dropout"],
bias=LORA_CONFIG["bias"],
use_gradient_checkpointing="unsloth",
random_state=3407,
max_seq_length=max_seq_length,
)
return model, tokenizer


def get_model(
base_model_id: str, model_architecture: str, world_size: int, unsloth: bool
base_model_id: str, model_architecture: str
) -> tuple[peft.PeftModel, AutoTokenizer]:
if model_architecture in ("llama", "stablelm", "mistral"):
LORA_CONFIG["target_modules"] = [
@@ -284,27 +344,9 @@ def get_model(
"fc2",
]
else:
# TODO: Just leaving as a placeholder for if/when we add other architectures.
pass

if unsloth:
assert model_architecture in (
"mistral",
"llama",
), "Unsloth only supports Mistral and LLaMA"
assert (
world_size == 1
), "TODO: Figure out how to do multi-GPU training with Unsloth"
assert (
_BNB_CONFIG is None or not _BNB_CONFIG.load_in_8bit
), "Unsloth only supports 4-bit and no quantization"
return _get_unsloth_model(base_model_id)

if world_size > 1:
assert rl.utils.io.getenv("LOCAL_RANK") is not None, "LOCAL_RANK must be set"
device_map = (
{"": int(rl.utils.io.getenv("LOCAL_RANK"))} if world_size > 1 else "auto"
)
device_map = "auto"
rl.utils.LOGGER.info(f"Device map for this process: {device_map}")
model = AutoModelForCausalLM.from_pretrained(
base_model_id,
@@ -315,82 +357,51 @@ def get_model(
attn_implementation="flash_attention_2",
)
model = peft.prepare_model_for_kbit_training(model)
model = peft.get_peft_model(model, peft.LoraConfig(**LORA_CONFIG)) # type: ignore
model = peft.get_peft_model(model, peft.LoraConfig(**LORA_CONFIG))
model.config.use_cache = False
torch.compile(model)
return model, get_tokenizer(base_model_id)


def _get_unsloth_model(base_model_id: str) -> tuple:
from unsloth import FastLanguageModel

max_seq_length = 4096
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=base_model_id,
max_seq_length=max_seq_length,
dtype=None,
load_in_4bit=_BNB_CONFIG.load_in_4bit if _BNB_CONFIG else False,
)
model = FastLanguageModel.get_peft_model(
model,
**LORA_CONFIG,
use_gradient_checkpointing="unsloth",
random_state=3407,
max_seq_length=max_seq_length,
)
tokenizer = get_tokenizer(base_model_id)
return model, tokenizer


def get_trainer(
*,
model,
learning_rate,
num_epochs,
batch_size,
gradient_accumulation_steps,
eval_steps,
world_size,
output_dir,
tokenizer,
full_dataset,
dpo=False,
dpo_beta=None,
deepspeed_config=None,
):
if world_size > 1:
gradient_accumulation_steps = gradient_accumulation_steps // world_size
if world_size == 1 and torch.cuda.device_count() > 1:
model.is_parallelizable = True
model.model_parallel = True

learning_rate: float,
num_epochs: float,
batch_size: int,
gradient_accumulation_steps: int,
eval_steps: int,
output_dir: Path,
tokenizer: AutoTokenizer,
full_dataset: datasets.Dataset,
dpo: bool = False,
dpo_beta: float = None,
deepspeed_config: Path = None,
unsloth: bool = False,
) -> Trainer:
training_args = TrainingArguments(
# Training hyperparameters
num_train_epochs=num_epochs,
learning_rate=learning_rate,
warmup_steps=15,
optim="paged_adamw_8bit",
optim="adamw_8bit",
bf16=torch.cuda.is_bf16_supported(),
fp16=not torch.cuda.is_bf16_supported(),
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
# Deepspeed
deepspeed=deepspeed_config,
# Save configuration
save_strategy="steps",
eval_strategy="steps",
save_steps=eval_steps,
eval_steps=eval_steps,
save_total_limit=6,
load_best_model_at_end=True,
output_dir=str(output_dir),
save_safetensors=False, # Safetensors breaks because of: https://huggingface.co/docs/safetensors/torch_shared_tensors
# Logging
report_to="wandb", # type: ignore
save_safetensors=False,
report_to="wandb",
logging_steps=2,
# Other
remove_unused_columns=False,
ddp_find_unused_parameters=False if world_size > 1 else None,
)
if batch_size:
training_args.per_device_train_batch_size = batch_size
@@ -404,7 +415,7 @@ def get_trainer(
assert all(
col in full_dataset["train"].column_names
for col in ("prompt", "chosen", "rejected")
), "DPO training requires 'prompt', 'chosen', and 'rejected' columns in the training data. Did you pass the right file?"
), "DPO training requires 'prompt', 'chosen', and 'rejected' columns in the training data."
from trl import DPOTrainer

trainer_class = DPOTrainer
@@ -416,8 +427,6 @@ def get_trainer(
tokenizer,
train_on_input=False,
predict_with_generate=False,
# input_max_len=4096,
# output_max_len=2048,
),
}

@@ -431,7 +440,9 @@ def get_trainer(
return trainer


def log_initial_run_info(base_model_id, run_name, train_data_path, val_data_path):
def log_initial_run_info(
base_model_id: str, run_name: str, train_data_path: Path, val_data_path: Path
):
train_data_md5_hash = hashlib.md5(train_data_path.read_bytes()).hexdigest()
val_data_md5_hash = (
hashlib.md5(val_data_path.read_bytes()).hexdigest()
@@ -457,15 +468,7 @@ def log_initial_run_info(base_model_id, run_name, train_data_path, val_data_path
)


def confirm_run(output_dir):
if any(output_dir.iterdir()):
click.confirm(
f"{output_dir} already exists and is not empty. Y for resume training, N to abort.",
abort=True,
)


def save_model(trainer, tokenizer, output_dir):
def save_model(trainer: Trainer, tokenizer: AutoTokenizer, output_dir: Path):
output_dir.parent.mkdir(parents=True, exist_ok=True)
trainer.save_model(str(output_dir))
tokenizer.save_pretrained(str(output_dir))
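Note: the new Unsloth path in main assumes one training process per GPU, with the rank environment variables supplied by an external launcher such as torchrun. A minimal, hypothetical sketch of that per-rank setup pattern (not part of this commit; the helper names and the launcher invocation are assumptions) could look like:

# Hypothetical per-rank setup mirroring the commit's Unsloth code path.
# Assumes a launcher, e.g. `torchrun --nproc_per_node=<num_gpus> rl/llm/train_llm.py ...`,
# has exported RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.
import torch
import torch.distributed as dist


def setup_rank() -> int:
    dist.init_process_group("nccl")  # reads the rank/world-size env vars set by the launcher
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # pin this process to one GPU, like device_map={"": local_rank}
    return rank


def teardown() -> None:
    dist.destroy_process_group()

On a single node the global rank doubles as the GPU index, which is the assumption the diff makes by passing DISTRIBUTED.get_rank() as local_rank.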
