Merge pull request #245 from huggingface/nouamane/fix-optim-states-resuming

Small fixes when resuming training
NouamaneTazi authored Nov 21, 2024
2 parents 51ca40b + ab8c145 commit 42040ae
Showing 4 changed files with 77 additions and 19 deletions.
10 changes: 8 additions & 2 deletions src/nanotron/optim/named_optimizer.py
@@ -45,7 +45,6 @@ def __init__(
for param in _params:
# https://github.com/pytorch/pytorch/issues/100701
assert param.numel() > 0

super().__init__(optimizer=optimizer_builder(params), id_to_name=id_to_name)

def state_dict(self) -> dict:
@@ -60,9 +59,16 @@ def state_dict(self) -> dict:
return optim_state_dict

def load_state_dict(self, state_dict: dict) -> None:
# TODO @thomasw21: Make a more robust test
assert set(self.id_to_name.values()) == set(
state_dict["names"].values()
), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}"

OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"})
assert len(state_dict["state"]) == len(state_dict["names"])
for key in OPTIMIZER_STATE_KEYS:
for k, state in state_dict["state"].items():
assert (
key in state
), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"

return super().load_state_dict(state_dict)
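
For context, a minimal sketch of the state-dict layout these new assertions expect (key names assume an Adam-style optimizer; the tensors are placeholders, not nanotron code):

    # Hypothetical layout: every per-param entry shares the same keys, plus a scalar "step".
    state_dict = {
        "names": {0: "model.layer.weight", 1: "model.layer.bias"},
        "state": {
            0: {"step": 100, "exp_avg": ..., "exp_avg_sq": ...},
            1: {"step": 100, "exp_avg": ..., "exp_avg_sq": ...},
        },
    }

    # The added check: every non-"step" state key must be present for every parameter.
    optimizer_state_keys = sorted(state_dict["state"][0].keys() - {"step"})
    assert len(state_dict["state"]) == len(state_dict["names"])
    for key in optimizer_state_keys:
        for idx, state in state_dict["state"].items():
            assert key in state, f"Key {key} missing for param {state_dict['names'][idx]}"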
21 changes: 21 additions & 0 deletions src/nanotron/sanity_checks.py
@@ -58,6 +58,7 @@ def before_tbi_sanity_checks(
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that the model params are synchronized across dp
@@ -84,6 +85,17 @@
msg=lambda err: f"[Before train] Tied weights {name} are not synchronized. {err}",
)

# SANITY CHECK: Check that model grads are zeroed or None
for name, param in unwrapped_model.named_parameters():
if param.grad is not None:
torch.testing.assert_close(
param.grad,
torch.zeros_like(param.grad),
atol=0,
rtol=0,
msg="Model half precision grads must be zeroed or None in first accumulation step.",
)

# SANITY CHECK: Check that the grad accumulator buffers are ready for DDP
if grad_accumulator is not None:
for _, elt in grad_accumulator.fp32_grad_buffers.items():
@@ -96,6 +108,15 @@
msg="Grad accumulator buffers must be zeroed in first accumulation step.",
)

# TODO: add checks for memory contiguousness

# SANITY CHECK: Check that optimizer's lr is synchronized with lr_scheduler
for i, group in enumerate(lr_scheduler.optimizer.param_groups):
assert (
group["lr"] == lr_scheduler.get_last_lr()[i]
), f"Optimizer and LR scheduler are not in sync. Got {group['lr']} and {lr_scheduler.get_last_lr()[i]}"
break

# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_tbi_sanity_checks()
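
The new learning-rate check guards against the optimizer's param-group lr drifting from what the scheduler reports after a resume. A standalone illustration of the invariant with a toy optimizer and scheduler (not nanotron code):

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.5 ** step)
    scheduler.step()  # both the scheduler and the param group should now hold 5e-4

    # The invariant being asserted: each param group's lr matches the scheduler's view of it.
    for i, group in enumerate(optimizer.param_groups):
        assert group["lr"] == scheduler.get_last_lr()[i], (
            f"Optimizer and LR scheduler are not in sync: {group['lr']} vs {scheduler.get_last_lr()[i]}"
        )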

59 changes: 44 additions & 15 deletions src/nanotron/serialize/optimizer.py
@@ -1,4 +1,6 @@
import json
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple

@@ -147,6 +149,9 @@ def load_optimizer(
if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
parallel_context.pp_pg.size()
):
warnings.warn(
"You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!"
)
assert (
param_shard_metadata is not None
), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -174,18 +179,24 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# NOTE: if the checkpoint is from a Zero-0 optimizer, then we don't need to merge the shards
# across data parallel dimension, just directly load the checkpoints
shard_paths = list(
root_folder.glob(f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt")
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt"
) # WARN: wildcard here after tp can hold `0-of-1_exp-0`
)

ckp_sharded_optim_states = {}
for shard_path in shard_paths:
pp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=False)
ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(shard_path, map_location=map_location)
ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(
shard_path, map_location=map_location
) # load all optim states in mem

model_state_dict = model.state_dict()
new_optim_state_dict = optimizer.state_dict()
new_optim_state_dict["state"] = defaultdict(dict)
# TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
OPTIMIZER_STATE_DTYPE = ckp_sharded_optim_states[(0, 0)]["state"][0][OPTIMIZER_STATE_NAMES[0]].dtype
# NOTE: because we can only resume training with the same optimizer type
# (0, 0) = (pp_rank, tp_rank)
# NOTE: also we don't merge "step" because it's just a scalar
@@ -224,14 +235,14 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# from an unsharded optimizer state's shape
new_shard_metadata = param.get_sharded_info()
new_unshared_shape = new_shard_metadata.unsharded_shape
new_optim_state_dict["state"][param_index] = {}
# NOTE: restore each state tensor (e.g. exp_avg) by iterating through
# the optimizer state shards saved using the previous topology
for state_key in OPTIMIZER_STATE_NAMES:
# TODO(xrsrke): free the memory of the shards that isn't
# corresponding to the current rank
buffer = torch.zeros_like(param, device="cuda")
unsharded_buffer = torch.empty(new_unshared_shape, device="cuda")
# TODO: maybe better to allocate memory for all states at once
buffer = torch.zeros_like(param, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)
unsharded_buffer = torch.empty(new_unshared_shape, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)

for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
@@ -266,17 +277,34 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
],
new_shard_metadata,
)
else:
# Handle non-sharded params (e.g. layernorm)
for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
)
if old_optim_state_index is None:
continue # Param not in this PP shard

if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: flatten the optimizer states
new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][
param_index
][state_key].flatten()
# NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
# try to get the step value as well.
step = ckp_optim_state["state"][old_optim_state_index].get("step")
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step
# For non-sharded params, just copy over the state directly
for state_key in OPTIMIZER_STATE_NAMES:
new_optim_state_dict["state"][param_index][state_key] = ckp_optim_state["state"][
old_optim_state_index
][state_key]

if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: flatten the optimizer states
new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][param_index][
state_key
].flatten()

# NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
# try to get the step value as well.
step = ckp_optim_state["state"][old_optim_state_index].get("step")
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step

# NOTE: we throw away ckp_optim_state['gradient_accumulator'] which has fp32 grads

new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
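
Taken together, the resharding path above is unshard-then-reshard: each old (pp, tp) shard of a state tensor such as exp_avg is copied into a buffer of the unsharded shape, and the slice owned by the current rank under the new topology is then taken from that buffer. A simplified sketch of the idea using plain dim-0 slices instead of nanotron's shard metadata (shapes and ranks are illustrative):

    import torch

    unsharded_shape = (8, 4)   # full shape of one optimizer state tensor (e.g. exp_avg)
    old_tp, new_tp = 2, 4      # resuming with a different tensor-parallel size

    # Old checkpoint: each of the old TP ranks saved its slice of the tensor along dim 0.
    rows_old = unsharded_shape[0] // old_tp
    old_shards = {tp: torch.randn(rows_old, unsharded_shape[1]) for tp in range(old_tp)}

    # 1) Rebuild the unsharded tensor by writing every old shard into its slice.
    unsharded = torch.empty(unsharded_shape)
    for tp, shard in old_shards.items():
        unsharded[tp * rows_old:(tp + 1) * rows_old] = shard

    # 2) Slice out the piece owned by the current rank under the new topology.
    new_rank = 1
    rows_new = unsharded_shape[0] // new_tp
    new_shard = unsharded[new_rank * rows_new:(new_rank + 1) * rows_new]
    assert new_shard.shape == (rows_new, unsharded_shape[1])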
@@ -319,3 +347,4 @@ def load_lr_scheduler(

state_dict = torch.load(root_folder / lr_scheduler_filename())
lr_scheduler.load_state_dict(state_dict)
lr_scheduler._initial_step() # NOTE: this is required to set the initial learning rate
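
The trailing lr_scheduler._initial_step() call uses a private PyTorch LRScheduler method; per the accompanying comment, the point is to re-apply the restored schedule's learning rate to the optimizer's param groups so the sanity check above holds. A rough equivalent of that re-sync, written as a hypothetical helper (the private call may additionally touch internal counters):

    def resync_lr(optimizer, lr_scheduler):
        # Push the scheduler's current learning rates back into the optimizer's param groups.
        for group, lr in zip(optimizer.param_groups, lr_scheduler.get_last_lr()):
            group["lr"] = lr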
6 changes: 4 additions & 2 deletions src/nanotron/trainer.py
@@ -195,7 +195,7 @@ def __init__(
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
param_shard_metadata=self.param_shard_metadata,
model=self.model,
model=self.unwrapped_model,
)

# Init learning rate scheduler
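
The switch to model=self.unwrapped_model presumably keeps name-based lookups consistent: load_optimizer matches optimizer state entries against the model's state_dict names, and a wrapped model (e.g. DDP) typically nests parameters under an extra attribute, so its keys would no longer line up. A generic illustration of that prefixing problem (not nanotron code):

    import torch

    class Wrapper(torch.nn.Module):
        # Stand-in for a DDP-style wrapper that nests the real model under `module`.
        def __init__(self, module):
            super().__init__()
            self.module = module

    net = torch.nn.Linear(2, 2)
    print(list(net.state_dict().keys()))           # ['weight', 'bias']
    print(list(Wrapper(net).state_dict().keys()))  # ['module.weight', 'module.bias']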
@@ -470,7 +470,9 @@ def train(
def training_step(
self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]
) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]:
before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
before_tbi_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler
)

if self.iteration_step < 5:
log_memory(logger=logger)