diff --git a/rl/llm/train_llm.py b/rl/llm/train_llm.py index 24623b3..af4b73a 100644 --- a/rl/llm/train_llm.py +++ b/rl/llm/train_llm.py @@ -1,5 +1,5 @@ # type: ignore -"""A script to train an LLM with LoRA.""" +"""A script to train an LLM with LoRA and Unsloth for multi-GPU training.""" import hashlib from pathlib import Path @@ -8,8 +8,15 @@ import pandas as pd import peft import torch +import torch.distributed as DISTRIBUTED import wandb -from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + HfArgumentParser, + Trainer, + TrainingArguments, +) import rl.llm.config import rl.llm.merge_lora @@ -26,14 +33,12 @@ "lora_alpha": 16, "lora_dropout": 0.0, "bias": "none", - # "task_type": "CAUSAL_LM", } _DPO_BETA = 0.1 _VALIDATION_SPLIT = 0.1 _MAX_VALIDATION_SIZE = 500 - _DEFAULT_BASE_OUTPUT_DIR = rl.utils.io.get_model_path("lora") _DEFAULT_MERGED_DIR = rl.utils.io.get_model_path("merged") @@ -193,10 +198,21 @@ def main( abort=True, ) + if unsloth: + DISTRIBUTED.init_process_group("nccl") + local_rank = DISTRIBUTED.get_rank() + rl.utils.LOGGER.info(f"Unsloth: Loading model on GPU device id = {local_rank}.") + else: + local_rank = 0 + full_dataset = get_dataset(train_data_path, val_data_path) - world_size = int(rl.utils.io.getenv("WORLD_SIZE", 1)) - model, tokenizer = get_model(base_model_id, model_architecture, world_size, unsloth) + if unsloth: + model, tokenizer = get_unsloth_model( + base_model_id, model_architecture, local_rank + ) + else: + model, tokenizer = get_model(base_model_id, model_architecture) trainer = get_trainer( model=model, @@ -205,26 +221,28 @@ def main( batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, eval_steps=eval_steps, - world_size=world_size, output_dir=output_dir, tokenizer=tokenizer, full_dataset=full_dataset, dpo=dpo, dpo_beta=dpo_beta, deepspeed_config=deepspeed_config, + unsloth=unsloth, ) - if trainer.accelerator.is_main_process: + if trainer.is_world_process_zero(): log_initial_run_info(base_model_id, name, train_data_path, val_data_path) trainer.train(resume_from_checkpoint=True if output_exists else None) - trainer.accelerator.wait_for_everyone() + trainer.save_state() - if trainer.accelerator.is_main_process: + if trainer.is_world_process_zero(): save_model(trainer, tokenizer, output_dir) - del trainer, tokenizer, model - if merge_after: + if unsloth: + DISTRIBUTED.destroy_process_group() + + if merge_after and trainer.is_world_process_zero(): merged_dir = _DEFAULT_MERGED_DIR / (name or output_dir.name) rl.llm.merge_lora.main.callback( base_model_id=base_model_id, @@ -239,7 +257,6 @@ def _get_default_output_dir(name: str) -> Path: def get_dataset(train_data_path: Path, val_data_path: Path) -> datasets.Dataset: df = pd.read_json(train_data_path, lines=True) - # Removing the metadata because it causes weird problems when loading the dataset. df = df.drop(columns=["metadata"]) dataset = datasets.Dataset.from_pandas(df) if val_data_path: @@ -254,15 +271,58 @@ def get_dataset(train_data_path: Path, val_data_path: Path) -> datasets.Dataset: return dataset -def get_tokenizer(base_model_id): +def get_tokenizer(base_model_id: str) -> AutoTokenizer: tokenizer = AutoTokenizer.from_pretrained(base_model_id) tokenizer.pad_token_id = 0 # unk tokenizer.padding_side = "left" return tokenizer +def get_unsloth_model( + base_model_id: str, model_architecture: str, local_rank: int +) -> tuple: + from unsloth import FastLanguageModel + + max_seq_length = 4096 + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=base_model_id, + max_seq_length=max_seq_length, + dtype=None, + load_in_4bit=_BNB_CONFIG.load_in_4bit if _BNB_CONFIG else False, + device_map={"": local_rank}, + ) + + if model_architecture in ("llama", "mistral"): + target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + elif model_architecture == "phi": + target_modules = ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"] + else: + target_modules = LORA_CONFIG.get("target_modules", []) + + model = FastLanguageModel.get_peft_model( + model, + r=LORA_CONFIG["r"], + lora_alpha=LORA_CONFIG["lora_alpha"], + target_modules=target_modules, + lora_dropout=LORA_CONFIG["lora_dropout"], + bias=LORA_CONFIG["bias"], + use_gradient_checkpointing="unsloth", + random_state=3407, + max_seq_length=max_seq_length, + ) + return model, tokenizer + + def get_model( - base_model_id: str, model_architecture: str, world_size: int, unsloth: bool + base_model_id: str, model_architecture: str ) -> tuple[peft.PeftModel, AutoTokenizer]: if model_architecture in ("llama", "stablelm", "mistral"): LORA_CONFIG["target_modules"] = [ @@ -284,27 +344,9 @@ def get_model( "fc2", ] else: - # TODO: Just leaving as a placeholder for if/when we add other architectures. pass - if unsloth: - assert model_architecture in ( - "mistral", - "llama", - ), "Unsloth only supports Mistral and LLaMA" - assert ( - world_size == 1 - ), "TODO: Figure out how to do multi-GPU training with Unsloth" - assert ( - _BNB_CONFIG is None or not _BNB_CONFIG.load_in_8bit - ), "Unsloth only supports 4-bit and no quantization" - return _get_unsloth_model(base_model_id) - - if world_size > 1: - assert rl.utils.io.getenv("LOCAL_RANK") is not None, "LOCAL_RANK must be set" - device_map = ( - {"": int(rl.utils.io.getenv("LOCAL_RANK"))} if world_size > 1 else "auto" - ) + device_map = "auto" rl.utils.LOGGER.info(f"Device map for this process: {device_map}") model = AutoModelForCausalLM.from_pretrained( base_model_id, @@ -315,68 +357,40 @@ def get_model( attn_implementation="flash_attention_2", ) model = peft.prepare_model_for_kbit_training(model) - model = peft.get_peft_model(model, peft.LoraConfig(**LORA_CONFIG)) # type: ignore + model = peft.get_peft_model(model, peft.LoraConfig(**LORA_CONFIG)) model.config.use_cache = False torch.compile(model) - return model, get_tokenizer(base_model_id) - - -def _get_unsloth_model(base_model_id: str) -> tuple: - from unsloth import FastLanguageModel - - max_seq_length = 4096 - model, tokenizer = FastLanguageModel.from_pretrained( - model_name=base_model_id, - max_seq_length=max_seq_length, - dtype=None, - load_in_4bit=_BNB_CONFIG.load_in_4bit if _BNB_CONFIG else False, - ) - model = FastLanguageModel.get_peft_model( - model, - **LORA_CONFIG, - use_gradient_checkpointing="unsloth", - random_state=3407, - max_seq_length=max_seq_length, - ) + tokenizer = get_tokenizer(base_model_id) return model, tokenizer def get_trainer( *, model, - learning_rate, - num_epochs, - batch_size, - gradient_accumulation_steps, - eval_steps, - world_size, - output_dir, - tokenizer, - full_dataset, - dpo=False, - dpo_beta=None, - deepspeed_config=None, -): - if world_size > 1: - gradient_accumulation_steps = gradient_accumulation_steps // world_size - if world_size == 1 and torch.cuda.device_count() > 1: - model.is_parallelizable = True - model.model_parallel = True - + learning_rate: float, + num_epochs: float, + batch_size: int, + gradient_accumulation_steps: int, + eval_steps: int, + output_dir: Path, + tokenizer: AutoTokenizer, + full_dataset: datasets.Dataset, + dpo: bool = False, + dpo_beta: float = None, + deepspeed_config: Path = None, + unsloth: bool = False, +) -> Trainer: training_args = TrainingArguments( - # Training hyperparameters num_train_epochs=num_epochs, learning_rate=learning_rate, warmup_steps=15, - optim="paged_adamw_8bit", + optim="adamw_8bit", bf16=torch.cuda.is_bf16_supported(), fp16=not torch.cuda.is_bf16_supported(), gradient_accumulation_steps=gradient_accumulation_steps, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, - # Deepspeed deepspeed=deepspeed_config, - # Save configuration save_strategy="steps", eval_strategy="steps", save_steps=eval_steps, @@ -384,13 +398,10 @@ def get_trainer( save_total_limit=6, load_best_model_at_end=True, output_dir=str(output_dir), - save_safetensors=False, # Safetensors breaks because of: https://huggingface.co/docs/safetensors/torch_shared_tensors - # Logging - report_to="wandb", # type: ignore + save_safetensors=False, + report_to="wandb", logging_steps=2, - # Other remove_unused_columns=False, - ddp_find_unused_parameters=False if world_size > 1 else None, ) if batch_size: training_args.per_device_train_batch_size = batch_size @@ -404,7 +415,7 @@ def get_trainer( assert all( col in full_dataset["train"].column_names for col in ("prompt", "chosen", "rejected") - ), "DPO training requires 'prompt', 'chosen', and 'rejected' columns in the training data. Did you pass the right file?" + ), "DPO training requires 'prompt', 'chosen', and 'rejected' columns in the training data." from trl import DPOTrainer trainer_class = DPOTrainer @@ -416,8 +427,6 @@ def get_trainer( tokenizer, train_on_input=False, predict_with_generate=False, - # input_max_len=4096, - # output_max_len=2048, ), } @@ -431,7 +440,9 @@ def get_trainer( return trainer -def log_initial_run_info(base_model_id, run_name, train_data_path, val_data_path): +def log_initial_run_info( + base_model_id: str, run_name: str, train_data_path: Path, val_data_path: Path +): train_data_md5_hash = hashlib.md5(train_data_path.read_bytes()).hexdigest() val_data_md5_hash = ( hashlib.md5(val_data_path.read_bytes()).hexdigest() @@ -457,15 +468,7 @@ def log_initial_run_info(base_model_id, run_name, train_data_path, val_data_path ) -def confirm_run(output_dir): - if any(output_dir.iterdir()): - click.confirm( - f"{output_dir} already exists and is not empty. Y for resume training, N to abort.", - abort=True, - ) - - -def save_model(trainer, tokenizer, output_dir): +def save_model(trainer: Trainer, tokenizer: AutoTokenizer, output_dir: Path): output_dir.parent.mkdir(parents=True, exist_ok=True) trainer.save_model(str(output_dir)) tokenizer.save_pretrained(str(output_dir))