
Commit

Merge pull request #246 from huggingface/nouamane/optim-state-cpu-offload

Optimize memory when loading checkpoint
NouamaneTazi authored Nov 25, 2024
2 parents 42040ae + f00a380 commit e694f6d
Showing 10 changed files with 249 additions and 83 deletions.
134 changes: 132 additions & 2 deletions src/nanotron/optim/base.py
@@ -1,7 +1,28 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)

import torch
from typing_extensions import TypeAlias

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]


class BaseOptimizer(ABC):
@@ -34,7 +55,7 @@ def state_dict(self) -> dict:
...

@abstractmethod
def load_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
...

@abstractmethod
@@ -46,3 +67,112 @@ def inherit_from(self, cls) -> bool:


Optimizer = TypeVar("Optimizer", BaseOptimizer, torch.optim.Optimizer)


# Modified from torch.optim.Optimizer._process_value_according_to_param_policy
@staticmethod
def _process_value_according_to_param_policy(
param: torch.Tensor,
value: torch.Tensor,
param_id: int,
param_groups: List[Dict[Any, Any]],
map_location: Optional[Union[str, torch.device]],
key: Hashable = None,
) -> torch.Tensor:
# If map_location is specified, use it instead of param.device
target_device = map_location if map_location is not None else param.device

fused = False
capturable = False
assert param_groups is not None
for pg in param_groups:
if param_id in pg["params"]:
fused = pg["fused"] if "fused" in pg else False
capturable = pg["capturable"] if "capturable" in pg else False
break

if key == "step":
if capturable or fused:
return value.to(dtype=torch.float32, device=target_device)
else:
return value
else:
if param.is_floating_point():
return value.to(dtype=param.dtype, device=target_device)
else:
return value.to(device=target_device)


# Modified from torch.optim.Optimizer.load_state_dict
@torch._disable_dynamo
def custom_load_state_dict(self, state_dict: StateDict, map_location: Union[str, torch.device]) -> None:
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
map_location (str or torch.device, optional): Device where to load the optimizer states.
If None, states will be loaded to the same device as their corresponding parameters.
Default: None
"""

# shallow copy, to be consistent with module API
state_dict = state_dict.copy()

for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
hook_result = pre_hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result

# Validate the state_dict
groups = self.param_groups

# Deepcopy as we write into saved_groups later to update state
saved_groups = deepcopy(state_dict["param_groups"])

if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of " "parameter groups")
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
)

# Update the state
id_map = dict(
zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
)

def _cast(param, value, param_id=None, param_groups=None, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
elif isinstance(value, dict):
return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
else:
return value

# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
else:
state[k] = v

# Update parameter groups, setting their 'params' value
def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
new_group["params"] = group["params"]
return new_group

param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups})

for post_hook in self._optimizer_load_state_dict_post_hooks.values():
post_hook(self)
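
For context, the map_location argument introduced above is the core of this PR's memory optimization: when a checkpoint is restored, optimizer state tensors can be materialized on CPU instead of being copied straight onto each parameter's device. Below is a minimal usage sketch; the toy model is illustrative only, and it assumes a recent PyTorch and Python where the optimizer load-state-dict hooks exist and module-level staticmethods are callable.

import torch

from nanotron.optim.base import custom_load_state_dict

# Build a toy optimizer with some state to restore (illustrative only).
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 8)).sum().backward()
optimizer.step()

checkpoint_state = optimizer.state_dict()  # stands in for a state dict read from disk

# With map_location="cpu", restored state tensors are placed on CPU rather than
# param.device, which avoids a transient device-memory spike while loading.
custom_load_state_dict(optimizer, checkpoint_state, map_location="cpu")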
23 changes: 18 additions & 5 deletions src/nanotron/optim/inherit_from_other_optimizer.py
@@ -1,16 +1,29 @@
from functools import cache
from typing import Callable, Dict, Optional, Set
from typing import Callable, Dict, Optional, Set, Union

import torch

from nanotron.optim.base import BaseOptimizer, Optimizer
from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict


class InheritFromOtherOptimizer(BaseOptimizer):
def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]):
self.optimizer: Optimizer = optimizer
self.id_to_name = id_to_name

        # If self.optimizer is a torch optimizer, swap in our custom load_state_dict
if isinstance(optimizer, torch.optim.Optimizer):
# Replace the load_state_dict method with our custom implementation that enables CPU offload
original_load_state_dict = optimizer.load_state_dict
optimizer.load_state_dict = (
lambda state_dict, map_location=None: custom_load_state_dict(
optimizer, state_dict, map_location=map_location
)
if map_location is not None
else original_load_state_dict(state_dict)
)

self.optimizer: Optimizer = optimizer

def __getstate__(self):
return self.optimizer.__getstate__()

@@ -33,8 +46,8 @@ def state_dict_additional_keys(self) -> Set[str]:
def state_dict(self) -> dict:
return self.optimizer.state_dict()

def load_state_dict(self, state_dict: dict) -> None:
return self.optimizer.load_state_dict(state_dict)
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
return self.optimizer.load_state_dict(state_dict, map_location=map_location)

def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
return self.optimizer.step(closure=closure)
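
The dispatch above is conditional: a call that passes map_location goes through custom_load_state_dict, while map_location=None falls back to torch's stock loader. A hedged sketch of driving the wrapper directly; the id_to_name construction here is illustrative, since nanotron normally builds that mapping when it creates the optimizer.

import torch

from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer

model = torch.nn.Linear(8, 8)
inner = torch.optim.AdamW(model.parameters())
# Illustrative id -> name mapping; the real one comes from nanotron's optimizer builder.
id_to_name = {id(p): name for name, p in model.named_parameters()}

optimizer = InheritFromOtherOptimizer(inner, id_to_name=id_to_name)
model(torch.randn(2, 8)).sum().backward()
optimizer.step()

state = optimizer.state_dict()
optimizer.load_state_dict(state)                      # original torch loader
optimizer.load_state_dict(state, map_location="cpu")  # custom loader, states kept on CPU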
12 changes: 7 additions & 5 deletions src/nanotron/optim/named_optimizer.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch

@@ -58,17 +58,19 @@ def state_dict(self) -> dict:
}
return optim_state_dict

def load_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
assert set(self.id_to_name.values()) == set(
state_dict["names"].values()
), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}"

assert len(state_dict["state"]) == len(
state_dict["names"]
), f"Number of params in loaded state dict ({len(state_dict['state'])}) doesn't match number of names ({len(state_dict['names'])})"
assert len(state_dict["state"]) > 0, "Loading empty state dict"
OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"})
assert len(state_dict["state"]) == len(state_dict["names"])
for key in OPTIMIZER_STATE_KEYS:
for k, state in state_dict["state"].items():
assert (
key in state
), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"

return super().load_state_dict(state_dict)
return super().load_state_dict(state_dict, map_location=map_location)
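
The added validation assumes the layout NamedOptimizer produces: state_dict["state"] is keyed the same way as state_dict["names"], and every per-parameter entry carries the same optimizer-state keys. A hedged illustration of that shape, where the "exp_avg"/"exp_avg_sq" keys are Adam-specific and shown purely as an example:

state_dict = {
    "names": {0: "decoder.weight", 1: "decoder.bias"},  # illustrative parameter names
    "state": {
        0: {"step": ..., "exp_avg": ..., "exp_avg_sq": ...},  # tensors elided
        1: {"step": ..., "exp_avg": ..., "exp_avg_sq": ...},
    },
    "param_groups": [...],
}

# The check then derives the per-param keys to validate, here ["exp_avg", "exp_avg_sq"]:
OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"})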
4 changes: 2 additions & 2 deletions src/nanotron/optim/optimizer_from_gradient_accumulator.py
@@ -67,7 +67,7 @@ def state_dict(self) -> dict:
state_dict["gradient_accumulator"] = self.gradient_accumulator.state_dict()
return state_dict

def load_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
gradient_accumulator_state_dict = state_dict.pop("gradient_accumulator")
super().load_state_dict(state_dict)
super().load_state_dict(state_dict, map_location=map_location)
self.gradient_accumulator.load_state_dict(gradient_accumulator_state_dict)
11 changes: 7 additions & 4 deletions src/nanotron/sanity_checks.py
@@ -164,6 +164,7 @@ def before_optim_step_sanity_checks(
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
optimizer: optim.BaseOptimizer,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Test tied weights gradients are synchronized
@@ -232,6 +233,9 @@ def before_optim_step_sanity_checks(
msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}",
)

# SANITY CHECK: Check that optimizer states are synchronized across DP
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)

# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_optim_step_sanity_checks()

@@ -259,12 +263,11 @@ def after_optim_step_sanity_checks(
unwrapped_model.after_optim_step_sanity_checks()


def check_optim_state_in_sync(optimizer: optim.BaseOptimizer, pg: dist.ProcessGroup):
for _, optim_state in sorted(optimizer.state_dict()["state"].items(), key=lambda x: x[0]):
def check_optim_state_in_sync(optim_state_dict: dict, pg: dist.ProcessGroup):
for _, optim_state in sorted(optim_state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
if name == "step":
tensor = tensor.to("cuda")

continue
assert_tensor_synced_across_pg(
tensor=tensor, pg=pg, msg=lambda err: f"{name} are not synced across DP {err}"
)
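
Note the signature change: check_optim_state_in_sync now takes a state dict instead of the optimizer, and "step" entries are skipped rather than moved to CUDA, since with CPU offload they may legitimately live on CPU. A short usage note (the second form is the call added to before_optim_step_sanity_checks above; passing a freshly loaded checkpoint state dict is a presumed use case):

# Previously the helper received the optimizer itself:
#     check_optim_state_in_sync(optimizer, parallel_context.dp_pg)
# It now receives a ready-made state dict:
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)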
1 change: 1 addition & 0 deletions src/nanotron/serialize/__init__.py
@@ -2,3 +2,4 @@
from nanotron.serialize.main import *
from nanotron.serialize.optimizer import *
from nanotron.serialize.random import *
from nanotron.serialize.weights import *
44 changes: 2 additions & 42 deletions src/nanotron/serialize/main.py
@@ -5,7 +5,6 @@
import torch
from datasets.download.streaming_download_manager import xPath
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR

from nanotron import distributed as dist
@@ -21,14 +20,12 @@
assert_tensor_synced_across_pg,
check_optim_state_in_sync,
)
from nanotron.serialize.metadata import CheckpointMetadata, TrainingMetadata, load_meta, save_meta
from nanotron.serialize.metadata import TrainingMetadata, save_meta
from nanotron.serialize.optimizer import (
load_lr_scheduler,
load_optimizer,
save_lr_scheduler,
save_optimizer,
)
from nanotron.serialize.weights import load_weights, save_weights
from nanotron.serialize.weights import save_weights

"""
We're going to use safetensors. The reason is that loading segments is going to be much easier
@@ -206,43 +203,6 @@ def save(
dist.barrier(parallel_context.world_pg)


def load(
model: nn.Module,
optimizer: optim.BaseOptimizer,
lr_scheduler,
parallel_context: ParallelContext,
root_folder: Path,
) -> CheckpointMetadata:
"""
Load checkpoint, raise if checkpoint is assumed corrupted. Inplace updates `model` and `optimizer` to have the newest parameters.
TODO @thomasw21: Make this topology agnostic
:param filepath: Path
:return:
"""
checkpoint_metadata = load_meta(parallel_context=parallel_context, root_folder=root_folder)
load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)

# SANITY CHECK: assert that optimizer's named_params still point to model's params (check only the first one)
if isinstance(optimizer, optim.ZeroDistributedOptimizer):
if (
len(optimizer.zero_named_param_groups) > 0
and len(optimizer.zero_named_param_groups[0]["named_params"]) > 0
):
optim_model_param_name, optim_model_param = optimizer.zero_named_param_groups[0]["named_params"][0]
if isinstance(model, DistributedDataParallel):
optim_model_param_name = f"module.{optim_model_param_name}"
param = next(p for n, p in model.named_parameters() if n == optim_model_param_name)
assert param.data_ptr() == optim_model_param.data_ptr()

load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
load_lr_scheduler(
lr_scheduler=lr_scheduler,
root_folder=root_folder,
)
return checkpoint_metadata


def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
"""Parse checkpoint path from config and download checkpoint from S3 if needed.