Commit
DPO -> FSDP2 (pytorch#1536)
ebsmothers authored Sep 11, 2024
1 parent 8451b0d commit df29d8a
Showing 3 changed files with 135 additions and 115 deletions.
8 changes: 3 additions & 5 deletions recipes/full_finetune_distributed.py
@@ -38,7 +38,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):
Features:
- FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.
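For orientation, the ``fsdp_cpu_offload`` and ``fsdp_reshard_after_forward`` options described above correspond roughly to FSDP2's ``fully_shard`` arguments. A minimal sketch, assuming PyTorch >= 2.4 and a placeholder ``model`` whose decoder layers live in ``model.layers`` (not this recipe's actual wiring):

from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

def shard_with_flags(model, fsdp_cpu_offload: bool, fsdp_reshard_after_forward: bool):
    # reshard_after_forward=True ~ FULL_SHARD; False ~ SHARD_GRAD_OP
    kwargs = {"reshard_after_forward": fsdp_reshard_after_forward}
    if fsdp_cpu_offload:
        # Keep parameters, gradients, and optimizer states on CPU
        kwargs["offload_policy"] = CPUOffloadPolicy()
    for layer in model.layers:
        fully_shard(layer, **kwargs)  # shard each transformer layer first
    fully_shard(model, **kwargs)      # then the root module
    return model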
@@ -123,9 +123,6 @@ def __init__(self, cfg: DictConfig) -> None:
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
]

# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
@@ -233,7 +230,8 @@ def setup(self, cfg: DictConfig) -> None:
# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)

log.info("Loss is initialized.")
if self._is_rank_zero:
log.info("Loss is initialized.")

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after both of these are initialized
235 changes: 127 additions & 108 deletions recipes/lora_dpo_distributed.py
@@ -16,12 +16,6 @@

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
@@ -30,10 +24,13 @@
from torchtune.modules import rlhf
from torchtune.modules.peft import (
disable_adapter,
DoRALinear,
get_adapter_params,
get_merged_lora_ckpt,
load_dora_magnitudes,
LoRALinear,
set_trainable_params,
validate_state_dict_for_lora,
validate_missing_and_unexpected_for_lora,
)
from torchtune.modules.rlhf.loss import SimPOLoss
from torchtune.recipe_interfaces import FTRecipeInterface
@@ -49,11 +46,11 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
in the TRL library: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L65
Features:
- FSDP. Supported using PyTorch's FSDP APIs. This can be parameterized using the
``fsdp_sharding_strategy`` config option. You can pass any value supported by
torch.distributed.fsdp.ShardingStrategy
(https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy).
For example, in your config, simply pass ``fsdp_sharding=NO_SHARD`` for DDP.
- FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
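Since the docstring points at TRL's DPO trainer, here is a minimal sketch of the standard DPO objective it refers to; variable names and the ``beta`` default are illustrative, not torchtune's exact code:

import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta: float = 0.1):
    # Log-ratios of the policy vs. the (adapter-disabled) reference model
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # Reward the margin between chosen and rejected responses, scaled by beta
    return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()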
@@ -148,9 +145,6 @@ def __init__(self, cfg: DictConfig) -> None:
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
]

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
@@ -234,6 +228,8 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=(
checkpoint_dict[training.ADAPTER_KEY]
@@ -251,7 +247,8 @@ def setup(self, cfg: DictConfig) -> None:
)

self._loss_fn = config.instantiate(cfg.loss)
log.info("Loss function is initialized.")
if self._is_rank_zero:
log.info("Loss is initialized.")

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after all of these are setup
@@ -290,102 +287,119 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
"""
Model initialization has some important considerations:
a. To minimize GPU peak memory, we load the model on CPU with the right
dtype. To ensure that we don't instantiate ``world_size`` number of models,
we initialize on meta_device for all ranks other than rank 0.
b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
model weights from checkpoint.
c. While wrapping the model with FSDP, we set ``sync_module_states``
to TRUE and broadcast module params and buffers from rank 0.
d. The ``device_id`` param ensures that the FSDP initialization happens on
the correct device.
a. To minimize GPU peak memory, we initialize the model on meta device with
the right dtype
b. All ranks call ``load_state_dict`` without spiking CPU RAM, since
full state dicts are loaded with ``torch.load(mmap=True)``
c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module`
"""

if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
log.info(
"FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
)
init_start = time.perf_counter()

with training.set_default_dtype(self._dtype):
model = config.instantiate(cfg_model)
with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

log.info(
f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

if enable_activation_checkpointing:
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# The model contains LoRA params which won't have any matching keys in
# the state dict. As a result, we need to load with strict=False.
# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=cfg_model.apply_lora_to_output,
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_name(name: str, module: nn.Module) -> bool:
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
name_list = name.split(".")
return (
len(name_list) == 2
and name_list[0] == "layers"
and str.isdigit(name_list[1])
)

# Load both the base model weights and (if available) the adapter weights. Both
# of these should happen only on Rank 0
model.load_state_dict(base_model_state_dict, strict=False)
if lora_weights_state_dict:
model.load_state_dict(lora_weights_state_dict, strict=False)
training.shard_model(
model=model,
shard_conditions=[_is_layer_name],
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)
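A quick check of what the name-based condition above matches; the module argument is unused, so ``None`` stands in:

# Names are relative to the model root; only direct children of ``layers`` are sharded.
assert _is_layer_name("layers.0", None)             # a decoder layer (or its AC wrapper)
assert not _is_layer_name("layers.0.attn", None)    # submodules ride along with their layer
assert not _is_layer_name("tok_embeddings", None)   # non-layer modules are not matched here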

if lora_weights_state_dict:
lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
# For non-zero ranks, load the model on meta device
with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

if self._dtype == torch.bfloat16:
model = model.to(torch.bfloat16)

# LoRA hyper-params needed for merging weights while saving checkpoints
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha
lora_missing, lora_unexpected = None, None

# Note: this needs to be set before wrapping with FSDP
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

model = FSDP(
module=model,
auto_wrap_policy=training.lora_fsdp_wrap_policy(
modules_to_wrap={modules.TransformerSelfAttentionLayer}
),
sharding_strategy=self._fsdp_sharding_strategy,
device_id=self._device,
# this recipe does not currently support mixed precision training
mixed_precision=None,
# Ensure we broadcast params and buffers from rank 0
sync_module_states=True,
# Initialize empty modules on all non-zero ranks
param_init_fn=(
lambda module: (
module.to_empty(device=torch.device("cuda"), recurse=False)
if not self._is_rank_zero
else None
)
),
# Initialize LoRA params and RoPE buffers
with training.set_default_dtype(self._dtype), self._device:
lora_device = "cpu" if fsdp_cpu_offload else self._device
for m in model.modules():
if (
isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
) and not lora_weights_state_dict:
# LoRA params may not be covered in the state dict
# if finetuning for the first time
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.initialize_parameters()
# RoPE is not covered in state dict
if hasattr(m, "rope_init"):
m.rope_init()

base_missing, base_unexpected = training.load_from_full_model_state_dict(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
is_dora = False
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
is_dora = True
m.initialize_dora_magnitude()
if is_dora:
load_dora_magnitudes(model)
validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
base_missing=base_missing,
base_unexpected=base_unexpected,
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)

# Ensure no params and buffers are on meta device
training.validate_no_params_on_meta_device(model)

if enable_activation_checkpointing:
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
if self._is_rank_zero:
log.info(
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)

@@ -399,12 +413,11 @@ def _setup_optimizer(
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
# Note: technically we should check _contains_fsdp for
# just the state dict of the adapter cfg, but should be equivalent
opt_state_dict = FSDP.optim_state_dict_to_load(
self._model, optimizer, opt_state_dict
training.load_from_full_optimizer_state_dict(
optimizer,
opt_state_dict,
self._device,
)
optimizer.load_state_dict(opt_state_dict)

if self._is_rank_zero:
log.info("Optimizer and loss are initialized.")
@@ -480,25 +493,28 @@ def save_checkpoint(
- Relevant recipe state if training is not complete
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
"""
Checkpointer will save the merged weights, adapter weights and recipe state in
different checkpoint files. To correctly resume from training, the adapter weights
and recipe state must be provided along with the base model weights."""
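For reference, the merged weights follow the usual LoRA update rule; a hedged sketch with illustrative shapes, not ``get_merged_lora_ckpt``'s actual code:

import torch

rank, alpha = 8, 16
W = torch.randn(32, 32)    # frozen base weight
A = torch.randn(rank, 32)  # lora_a
B = torch.zeros(32, rank)  # lora_b is zero-initialized, so W_merged == W before training
W_merged = W + (alpha / rank) * (B @ A)  # fold the adapter into the base weight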
# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
cpu_state_dict = training.get_full_model_state_dict(
self._model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state_dict = self._model.state_dict()
if intermediate_checkpoint:
opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
else:
opt_state_dict = None
self._is_rank_zero,
device=self._device,
)
if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
device=self._device,
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
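Conceptually, ``get_full_model_state_dict`` gathers each sharded parameter (a DTensor under FSDP2) into a full tensor and keeps a CPU copy only on rank 0. A rough sketch under that assumption, not the helper's actual code:

def gather_full_state_dict(model, is_rank_zero: bool):
    cpu_state_dict = {}
    for name, param in model.state_dict().items():
        # DTensor.full_tensor() all-gathers the shards on every rank
        full = param.full_tensor() if hasattr(param, "full_tensor") else param
        if is_rank_zero:
            cpu_state_dict[name] = full.cpu()  # offload to CPU so GPU memory does not spike
    return cpu_state_dict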
@@ -731,7 +747,10 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)

if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")

config.log_config(recipe_name="LoRADPORecipeDistributed", cfg=cfg)
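A hedged sketch of what a helper like ``training.set_torch_num_threads`` might do; the actual torchtune utility may divide the core count differently across local ranks:

import os
import torch
import torch.distributed as dist

def set_torch_num_threads() -> None:
    # Give each process a share of the CPU cores for intra-op parallelism,
    # which speeds up CPU-offloaded work such as a fused AdamW step.
    divisor = dist.get_world_size() if dist.is_initialized() else 1
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // divisor))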
7 changes: 5 additions & 2 deletions recipes/lora_finetune_distributed.py
@@ -46,8 +46,11 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
distributed training and can be run on a single node (1 to 8 GPUs).
Features:
- FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU is not
supported.
- FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
