[checkpointio] unsharded optimizer checkpoint for Gemini plugin
Fridge003 committed Jul 5, 2023
1 parent 190a6ea commit 7e2028e
Showing 11 changed files with 607 additions and 79 deletions.
78 changes: 52 additions & 26 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -33,44 +33,40 @@ def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()

def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
Save unsharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, this must be called on all processes.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
# as there is communication when get state dict, this must be called on all processes
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Save optimizer to checkpoint but only on master process.
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
# TODO(ver217): optimizer state dict is sharded
warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)
super().load_unsharded_model(model, checkpoint, strict=strict)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save model to checkpoint but only on master process.
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, this must be called on all processes.
The saving process will only be executed by master rank.
"""
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Load unsharded optimizer from checkpoint file.
Each process only loads the optimizer states of the parameters it controls.
"""
super().load_unsharded_optimizer(optimizer, checkpoint)

def save_sharded_model(self,
model: GeminiDDP,
@@ -82,6 +78,12 @@ def save_sharded_model(self,
"""
Save sharded model
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return

Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
@@ -117,6 +119,30 @@ def load_sharded_model(self,
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
Path(checkpoint).mkdir(parents=True, exist_ok=True)
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)

def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
"""
Load sharded optimizer from checkpoint folder, with index file given.
Each process only loads the optimizer states of the parameters it controls.
"""
# TODO(Baizhou): To be implemented.
pass

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)


class GeminiModel(ModelWrapper):

@@ -219,7 +245,7 @@ def __init__(
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
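
Taken together, these GeminiCheckpointIO changes let a Gemini-wrapped optimizer be saved as one full checkpoint through the Booster API. A hedged usage sketch, assuming a torchrun launch and the Booster/GeminiPlugin keyword names on this branch (the toy model is a placeholder, and exact signatures may differ slightly):

import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))    # placeholder model
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# ... forward / backward / optimizer.step() as usual ...

# Every rank must call save_optimizer (gathering states involves communication),
# but only the master rank writes the single unsharded file.
booster.save_optimizer(optimizer, "optim.bin", shard=False)

# On reload, each rank keeps only the states of the parameters it manages.
booster.load_optimizer(optimizer, "optim.bin")
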
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/checkpoint_io_base.py
@@ -152,6 +152,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""

index_file_exists, index_file_path = has_index_file(checkpoint)

if Path(checkpoint).is_dir() and not index_file_exists:
@@ -186,6 +187,7 @@ def save_optimizer(self,
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""

if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else:
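
For reference, both entry points dispatch on the shard flag; a hedged sketch of driving them directly through GeneralCheckpointIO (the plain torch optimizer and paths are illustrative):

import torch
from colossalai.checkpoint_io import GeneralCheckpointIO

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ckpt_io = GeneralCheckpointIO()

# shard=False: one file holding the whole optimizer state dict
ckpt_io.save_optimizer(optimizer, "optim.bin", shard=False)

# shard=True: a folder with an index file, a param_groups file and state shards
ckpt_io.save_optimizer(optimizer, "optim_ckpt", shard=True, size_per_shard=1024)

# load_optimizer checks whether `checkpoint` is a plain file or an indexed folder
ckpt_io.load_optimizer(optimizer, "optim.bin")
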
14 changes: 9 additions & 5 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -28,6 +28,7 @@
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)

__all__ = ['GeneralCheckpointIO']
@@ -59,7 +60,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.optim
optimizer = unwrap_optimizer(optimizer)

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -96,6 +97,11 @@ def save_sharded_optimizer(
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = unwrap_optimizer(optimizer)

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -121,9 +127,8 @@ def save_sharded_optimizer(
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)

for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

@@ -177,7 +182,6 @@ def save_sharded_model(self,
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)

checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

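
With the new save_sharded_optimizer, the checkpoint folder ends up looking roughly like this (the group/shard filenames follow the docstring above; the index filename is an assumption):

optim_ckpt/
    pytorch_optim.bin.index.json    # weight map: param id -> shard file holding its states
    pytorch_optim_group.bin         # param_groups (lr, betas, weight_decay, ...)
    pytorch_optim-00001.bin         # first shard of per-parameter state tensors
    pytorch_optim-00002.bin         # further shards, each capped at roughly size_per_shard MB
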
22 changes: 18 additions & 4 deletions colossalai/checkpoint_io/utils.py
@@ -10,6 +10,8 @@
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import is_distributed_tensor

SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -88,6 +90,17 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
'''
Unwrap a wrapped optimizer.
This method should be used before saving/loading it to/from sharded checkpoints.
'''
unwrapped_optim = optimizer.optim
if isinstance(unwrapped_optim, ColossalaiOptimizer):
unwrapped_optim = unwrapped_optim.optim
return unwrapped_optim


def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -103,7 +116,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
weight_size = calculate_tensor_size(weight)

# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
if current_block_size + weight_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
@@ -140,9 +153,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
isDTensor = False
for state_tensor in state.values():

# When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
# When state_tensor is not a Tensor
# (e.g., an SGD optimizer with momentum set to 0 can have None as its state),
# The calculation of tensor size should be skipped to avoid error.
if state_tensor is None:
if not isinstance(state_tensor, torch.Tensor):
continue

# If the states are stored as DTensors, mark isDTensor as true.
@@ -152,7 +166,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->

if not isDTensor:

if current_block_size + state_size > max_shard_size:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
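
The added `and current_block_size > 0` guard in both sharding helpers matters when a single tensor is larger than max_shard_size: without it, the loop would first emit an empty shard. A minimal, self-contained sketch of the same accumulation pattern (the sizes are illustrative; the real code measures tensors with calculate_tensor_size and yields (shard, size) pairs):

def split_into_shards(named_sizes, max_shard_size):
    current, current_size = {}, 0
    for name, size in named_sizes:
        # Only cut a new shard if the current one is non-empty; an oversized
        # tensor then simply becomes its own (oversized) shard.
        if current_size + size > max_shard_size and current_size > 0:
            yield current, current_size
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    yield current, current_size

# A 1500 MB tensor with max_shard_size=1024 lands in its own shard,
# with no empty shard emitted before it.
print(list(split_into_shards([("w1", 1500), ("w2", 10)], 1024)))
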
6 changes: 6 additions & 0 deletions colossalai/interface/optimizer.py
@@ -119,3 +119,9 @@ def unscale_grad(self):
"""
raise NotImplementedError(
"The method unscale_grad is only available for optimizers with mixed precision training")

def unwrap(self):
"""
Unwrap the optimizer for checkpoint saving/loading.
"""
return self.optim
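
A hedged sketch of how the two unwrapping paths relate; the helper below merely restates the logic of unwrap_optimizer for illustration:

from colossalai.checkpoint_io.utils import unwrap_optimizer
from colossalai.interface import OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer

def to_raw_torch_optimizer(optim):
    # OptimizerWrapper.unwrap() peels off one layer; unwrap_optimizer() also
    # handles the case where the wrapped object is itself a ColossalaiOptimizer.
    if isinstance(optim, OptimizerWrapper):
        return unwrap_optimizer(optim)
    if isinstance(optim, ColossalaiOptimizer):
        return optim.optim
    return optim
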
3 changes: 3 additions & 0 deletions colossalai/nn/optimizer/cpu_adam.py
@@ -176,3 +176,6 @@ def step(self, closure=None, div_scale: float = -1):
raise RuntimeError
self._post_step()
return loss

def get_state_names(self):
return ['step', 'exp_avg', 'exp_avg_sq']
3 changes: 3 additions & 0 deletions colossalai/nn/optimizer/fused_adam.py
@@ -147,3 +147,6 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
group['weight_decay'], div_scale)

return loss

def get_state_names(self):
return ['step', 'exp_avg', 'exp_avg_sq']
3 changes: 3 additions & 0 deletions colossalai/nn/optimizer/hybrid_adam.py
@@ -152,3 +152,6 @@ def step(self, closure=None, div_scale: float = -1):
bias_correction, group['weight_decay'], div_scale)
self._post_step()
return loss

def get_state_names(self):
return ['step', 'exp_avg', 'exp_avg_sq']
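
CPUAdam, FusedAdam and HybridAdam now all report the same Adam state names. A hedged sketch of querying them (assumes the ColossalAI CPU/CUDA Adam kernels are available in your build):

import torch
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(4, 4)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# Same answer for CPUAdam and FusedAdam as well.
print(optimizer.get_state_names())    # ['step', 'exp_avg', 'exp_avg_sq']
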
47 changes: 31 additions & 16 deletions colossalai/testing/comparison.py
@@ -16,7 +16,12 @@ def assert_not_equal(a: Tensor, b: Tensor):


def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
assert_close(a, b, rtol=rtol, atol=atol)
assert_close(a,
b,
rtol=rtol,
atol=atol,
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
dtype: {a.dtype} vs {b.dtype}")


def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
@@ -33,25 +38,35 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):


def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
for k, v in d1.items():
if isinstance(v, dict):
check_state_dict_equal(v, d2[k])
elif isinstance(v, list):
for i in range(len(v)):
if isinstance(v[i], torch.Tensor):
assert len(list(d1.keys())) == len(list(d2.keys()))
for k, v1 in d1.items():
assert k in d2
v2 = d2[k]
if isinstance(v1, dict):
assert isinstance(v2, dict)
check_state_dict_equal(v1, v2, ignore_device)
elif isinstance(v1, list):
assert isinstance(v2, list)
for v1_i, v2_i in zip(v1, v2):
if isinstance(v1_i, torch.Tensor):
assert isinstance(v2_i, torch.Tensor)
if not ignore_device:
v[i] = v[i].to("cpu")
d2[k][i] = d2[k][i].to("cpu")
assert torch.equal(v[i], d2[k][i])
v1_i = v1_i.to("cpu")
v2_i = v2_i.to("cpu")
assert_close_loose(v1_i, v2_i)
elif isinstance(v1_i, dict):
assert isinstance(v2_i, dict)
check_state_dict_equal(v1_i, v2_i, ignore_device)
else:
assert v[i] == d2[k][i]
elif isinstance(v, torch.Tensor):
assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}"
elif isinstance(v1, torch.Tensor):
assert isinstance(v2, torch.Tensor)
if not ignore_device:
v = v.to("cpu")
d2[k] = d2[k].to("cpu")
assert torch.equal(v, d2[k])
v1 = v1.to("cpu")
v2 = v2.to("cpu")
assert_close_loose(v1, v2)
else:
assert v == d2[k]
assert v1 == v2, f"{v1} not equals to {v2}"


def assert_hf_output_close(out1: Any,
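
check_state_dict_equal now recurses into nested dicts and lists and compares tensors with assert_close_loose, which is what a save/reload round trip needs. A hedged sketch of that usage (the plain torch objects and the path are illustrative):

import torch
from colossalai.testing.comparison import check_state_dict_equal

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

torch.save(optimizer.state_dict(), "optim.bin")
new_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
new_optimizer.load_state_dict(torch.load("optim.bin"))

# Recursively checks keys, nested structures and tensor values.
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
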