diff --git a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml new file mode 100644 index 0000000000..1cc864b900 --- /dev/null +++ b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml @@ -0,0 +1,130 @@ +# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py +# using a teacher and student model +# +# This config assumes that you've ran the following commands before launching KD: +# First download the student and teacher models +# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/knowledge_distillation_distributed +# +# This config works best for distilling on 2+ devices. + + +# Model Arguments +model: + _component_: torchtune.models.llama3_2.lora_llama3_2_1b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +teacher_model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/ + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-3.2-1B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Teacher checkpoint +teacher_checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +kd_loss: + _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss +kd_ratio: 0.5 + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 32 + +# Logging +output_dir: /tmp/kd_output +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml new file mode 100644 index 0000000000..9727860ca7 --- /dev/null +++ b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml @@ -0,0 +1,123 @@ +# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py +# using a teacher and student model +# +# This config assumes that you've ran the following commands before launching KD: +# First download the student and teacher models +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None +# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None +# +# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora +# +# To launch on 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed +# +# This config works best for distilling on 2+ devices. + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_0_5b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: False + lora_rank: 32 + lora_alpha: 64 + +teacher_model: + _component_: torchtune.models.qwen2.qwen2_1_5b + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-kd + model_type: QWEN2 + +teacher_checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune + checkpoint_files: [ + hf_model_0001_0.pt + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune + model_type: QWEN2 + +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 8 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +kd_loss: + _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss +kd_ratio: 0.5 + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 + +# Logging +output_dir: /tmp/qwen_kd +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py new file mode 100644 index 0000000000..d17e480ba6 --- /dev/null +++ b/recipes/knowledge_distillation_distributed.py @@ -0,0 +1,980 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group + +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune.datasets import ConcatDataset +from torchtune.modules.peft import ( + DoRALinear, + get_adapter_params, + get_lora_module_names, + get_merged_lora_ckpt, + load_dora_magnitudes, + LoRALinear, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class KDRecipeDistributed(FTRecipeInterface): + """ + Knowledge distillation recipe for dense transformer-based LLMs such as Llama3. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * gradient accumulation steps. + + For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + # Reduced precision logic + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + # fp16 precision is explicitly disabled as it is not supported in this + # recipe (for example, no gradient scaling). + if self._dtype == torch.float16: + raise ValueError( + "fp16 precision is not supported in this recipe. Please use fp32 or bf16." + ) + + _, rank = training.get_world_size_and_rank() + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + self._is_rank_zero = rank == 0 + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # training attributes + self._enable_activation_checkpointing = cfg.enable_activation_checkpointing + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._kd_ratio = cfg.get("kd_ratio", 0.5) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + if training.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the teacher checkpoint state from file. + """ + teacher_checkpointer = config.instantiate( + cfg_checkpointer, + ) + checkpoint_dict = teacher_checkpointer.load_checkpoint() + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + self._compile = cfg.get("compile", False) + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + teacher_checkpoint_dict = self.load_teacher_checkpoint( + cfg_checkpointer=cfg.teacher_checkpointer + ) + + # set up model + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + base_model_state_dict=checkpoint_dict[training.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[training.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._teacher_model = self._setup_teacher_model( + model_cfg=cfg.teacher_model, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY], + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + self._kd_loss_fn = config.instantiate(cfg.kd_loss) + if self._compile: + self._loss_fn = training.compile_loss( + self._loss_fn, verbose=self._is_rank_zero + ) + self._kd_loss_fn = training.compile_loss( + self._kd_loss_fn, verbose=self._is_rank_zero + ) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks) + # assert _loss_fn and _kd_loss_fn have the same num_output_chunks + assert ( + self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks + ), "Number of output chunks for loss_fn and kd_loss_fn must be the same." + + if self._is_rank_zero: + log.info("Loss is initialized.") + + # Dataloader depends on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + base_model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module` + """ + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + + if self._is_rank_zero: + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + if lora_weights_state_dict: + lora_missing, lora_unexpected = training.load_from_full_model_state_dict( + model, + lora_weights_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + else: + lora_missing, lora_unexpected = None, None + + # Initializer for LoRA params and RoPE buffers + with training.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if fsdp_cpu_offload else self._device + for m in model.modules(): + if ( + isinstance(m, LoRALinear) or isinstance(m, DoRALinear) + ) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.lora_a.to_empty(device=lora_device) + m.lora_b.to_empty(device=lora_device) + m.initialize_parameters() + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + base_missing, base_unexpected = training.load_from_full_model_state_dict( + model, + base_model_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + is_dora = False + for m in model.modules(): + if hasattr(m, "initialize_dora_magnitude"): + is_dora = True + m.initialize_dora_magnitude() + if is_dora: + load_dora_magnitudes(model) + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + if self._is_rank_zero: + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_teacher_model( + self, + model_cfg: DictConfig, + custom_sharded_layers: Optional[List[str]], + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + """ + Model initialization for teacher model has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + """ + + if self._is_rank_zero: + log.info( + "FSDP enabled. Instantiating teacher model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(model_cfg) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + self._is_rank_zero, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # Put model in eval mode. + # Note: This will not disable the dropout applied in SDPA, + # see https://github.com/pytorch/pytorch/issues/124464 + model.eval() + + for p in model.parameters(): + p.requires_grad = False + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + if self._is_rank_zero: + log.info( + f"Instantiating teacher model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + if self._is_rank_zero: + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports + Map-style Datasets which fit into memory and an option for random shuffling. + Samplers, iterable datasets, and streaming datasets are not supported. + """ + world_size, rank = training.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=( + partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, + ) + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = training.get_full_model_state_dict( + self._model, + self._is_rank_zero, + device=self._device, + ) + + if intermediate_checkpoint: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) + } + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + + # merge the adapter weights and base weights to create the model checkpoint + merged_state_dict = get_merged_lora_ckpt( + cpu_state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config}) + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def _loss_step( + self, batch: Dict[str, torch.Tensor] + ) -> (torch.Tensor, torch.Tensor): + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + # run model + logits = self._model(tokens, mask=mask, input_pos=input_pos) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute KD loss + with torch.no_grad(): + teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) + + # Compute kd loss + kd_loss = self._kd_loss_fn(logits, teacher_logits, labels) + + # Compute loss + loss = self._loss_fn(logits, labels) + + # free logits otherwise it peaks backward memory + del logits + del teacher_logits + + return loss, kd_loss + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + training.cleanup_before_training() + + _, rank = training.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_class_loss = 0 + running_kd_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + batch = {k: v.to(self._device) for k, v in batch.items()} + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + class_loss, kd_loss = self._loss_step(batch) + running_class_loss += class_loss * current_num_tokens + running_kd_loss += kd_loss * current_num_tokens + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + class_loss = running_class_loss / num_tokens + kd_loss = running_kd_loss / num_tokens + loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss + loss.backward() + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + # Update the number of steps when the weights are updated + self.global_step += 1 + + class_loss_to_log = class_loss.item() + kd_loss_to_log = kd_loss.item() + loss_to_log = ( + 1 - self._kd_ratio + ) * class_loss_to_log + self._kd_ratio * kd_loss_to_log + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "class_loss": class_loss_to_log, + "kd_loss": kd_loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_class_loss = 0 + running_kd_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + config.log_config(recipe_name="KDRecipeDistributed", cfg=cfg) + + recipe = KDRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py new file mode 100644 index 0000000000..949883ac48 --- /dev/null +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from omegaconf import OmegaConf +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) +from torchtune import config + + +class TestKDDistributedRecipe: + def _get_test_config_overrides(self, epochs: int = 2): + return [ + "batch_size=4", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + f"epochs={epochs}", + "dtype=fp32", + "max_steps_per_epoch=2", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "gradient_accumulation_steps=1", + "compile=False", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama3": [11.8316, 11.7520, 11.7642, 11.7664], + } + return loss_values_map[model_type] + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + def test_loss(self, tmpdir, monkeypatch): + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={log_file} \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] + ] + + cmd = cmd + self._get_test_config_overrides() + model_config + teacher_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + # only take the first loss + num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs + loss_values = loss_values[0::num_losses] + expected_loss_values = self._fetch_expected_loss_values("llama3") + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + tokenizer.path={tokenizer_path} \ + tokenizer.prompt_template=null \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] + ] + + cmd_1 = ( + cmd_1 + self._get_test_config_overrides() + model_config + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + tokenizer.path={tokenizer_path} \ + tokenizer.prompt_template=null \ + """.split() + cmd_2 = ( + cmd_2 + + self._get_test_config_overrides(epochs=3) + + model_config + + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Second epoch only + expected_loss_values = self._fetch_expected_loss_values("llama3")[2:] + loss_values = get_loss_values_from_metric_logger(log_file) + # only take the first loss + num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs + loss_values = loss_values[0::num_losses][:2] + + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): + ckpt_type = "tune" + model_type = "llama3" + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + teacher_checkpointer._component_={ckpt_component} \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={log_file} \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS[model_type] + ] + + cmd = cmd + self._get_test_config_overrides() + model_config + teacher_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Next load both the merged weights in a Llama3 base model + # and the base model weights + trained adapter weights in the LoRA Llama 3 model + # The results of calling forward on dummy inputs should be the same. + inputs = torch.randint(low=0, high=32_000, size=(2, 100)) + + # Build LoRA model for loading base + adapter weights separately + lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) + + # Build base llama3 model for loading merged weights + base_llama3_config = MODEL_TEST_CONFIGS[model_type] + llama3_model = config.instantiate( + OmegaConf.from_dotlist(base_llama3_config).model + ) + + # Load base model and trained adapter weights into LoRA model and call fwd + with open(f"{tmpdir}/adapter_1.pt", "rb") as f: + lora_sd = torch.load(f, weights_only=True) + with open(ckpt_path, "rb") as f: + base_model_sd = torch.load(f, weights_only=True) + lora_model.load_state_dict(lora_sd, strict=False) + lora_model.load_state_dict(base_model_sd, strict=False) + baseline_out = lora_model(inputs) + + # Load merged final ckpt directly into 3 and call fwd + with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: + sd = torch.load(f, weights_only=True) + llama3_model.load_state_dict(sd) + merged_ckpt_out = llama3_model(inputs) + torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 0933bc80ea..7bed74a6e7 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -395,6 +395,21 @@ class Recipe: ], supports_distributed=False, ), + Recipe( + name="knowledge_distillation_distributed", + file_path="knowledge_distillation_distributed.py", + configs=[ + Config( + name="qwen2/knowledge_distillation_distributed", + file_path="qwen2/knowledge_distillation_distributed.yaml", + ), + Config( + name="llama3_2/knowledge_distillation_distributed", + file_path="llama3_2/knowledge_distillation_distributed.yaml", + ), + ], + supports_distributed=True, + ), ]