[checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
Fridge003 committed Jul 7, 2023
1 parent 9775e9b commit 6cb7920
Showing 7 changed files with 84 additions and 104 deletions.
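As the commit title says, Gemini's unsharded optimizer checkpoint is now assembled with an all_gather collective instead of the previous master-rank send/recv protocol. Below is a minimal sketch of the pattern, assuming an already initialized torch.distributed process group; the function and argument names are illustrative, not ColossalAI APIs.

import torch.distributed as dist

def gather_shards(local_shard, process_group=None, only_rank_0=True):
    """Gather one picklable shard object from every rank in the group.

    Ranks that own nothing pass None. all_gather_object hands every rank the
    full list, so no point-to-point send/recv bookkeeping is needed.
    """
    world_size = dist.get_world_size(process_group)
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, local_shard, group=process_group)

    # Collect on rank 0 only, or on every rank when only_rank_0 is False.
    is_collector = (dist.get_rank(process_group) == 0) or (not only_rank_0)
    if is_collector:
        return [shard for shard in gathered if shard is not None]
    return None

Because all_gather_object delivers every rank's contribution to every rank, the same code path can serve both only_rank_0=True and only_rank_0=False, which is what lets state_dict(only_rank_0=False) work in the changes below.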
3 changes: 0 additions & 3 deletions colossalai/nn/optimizer/cpu_adam.py
@@ -176,6 +176,3 @@ def step(self, closure=None, div_scale: float = -1):
raise RuntimeError
self._post_step()
return loss

def get_state_names(self):
return ['step', 'exp_avg', 'exp_avg_sq']
3 changes: 0 additions & 3 deletions colossalai/nn/optimizer/fused_adam.py
@@ -147,6 +147,3 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
group['weight_decay'], div_scale)

return loss

def get_state_names(self):
return ['step', 'exp_avg', 'exp_avg_sq']
3 changes: 0 additions & 3 deletions colossalai/nn/optimizer/hybrid_adam.py
@@ -152,6 +152,3 @@ def step(self, closure=None, div_scale: float = -1):
bias_correction, group['weight_decay'], div_scale)
self._post_step()
return loss

def get_state_names(self):
return ['step', 'exp_avg', 'exp_avg_sq']
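All three optimizers drop get_state_names() because the new collection path in gemini_optimizer.py reads the state names off the optimizer's live state mapping instead of a hard-coded list. A small illustration with vanilla torch.optim.Adam as a stand-in (ColossalAI's CPUAdam/FusedAdam/HybridAdam keep their states under the same keys):

import torch

param = torch.nn.Parameter(torch.randn(4))
optim = torch.optim.Adam([param], lr=1e-3)

param.grad = torch.randn_like(param)
optim.step()  # populates optim.state[param] with 'step', 'exp_avg', 'exp_avg_sq'

# Equivalent of the removed get_state_names(): read the keys off the live state.
state_names = list(optim.state[param].keys())
print(state_names)  # ['step', 'exp_avg', 'exp_avg_sq']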
3 changes: 2 additions & 1 deletion colossalai/testing/comparison.py
@@ -39,7 +39,8 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):


def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
assert len(list(d1.keys())) == len(list(d2.keys()))
assert len(list(d1.keys())) == len(list(d2.keys())), \
f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
for k, v1 in d1.items():
assert k in d2
v2 = d2[k]
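For context, a tiny standalone illustration of the message the strengthened assertion produces when the two dicts diverge in key count (the inputs are made up; the real helper goes on to compare keys, values and, optionally, devices):

from collections import OrderedDict

d1 = OrderedDict(a=1, b=2, c=3)
d2 = OrderedDict(a=1, b=2)

assert len(list(d1.keys())) == len(list(d2.keys())), \
    f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
# AssertionError: Number of keys unequal: 3 vs 2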
146 changes: 66 additions & 80 deletions colossalai/zero/gemini/gemini_optimizer.py
@@ -1,8 +1,8 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import gc
import math
import warnings
from collections import abc as container_abcs
from typing import Any, Dict, Set, Tuple

import torch
@@ -348,10 +348,11 @@ def get_offsets(self, param_id: int) -> tuple:
chunk_offset(int): Offset of parameter inside the chunk.
shard_offset(int): Offset of its optimizer state shard
relative to the whole optimizer state.
param_size(int): Length of parameter shard owned by current process.
shard_size(int): Length of parameter shard owned by current process.
'''

assert param_id in self.id_to_fake_params
if param_id not in self.id_to_fake_params:
return -1, -1, -1
fake_param = self.id_to_fake_params[param_id]
chunk = self.param_to_chunk32[fake_param].paired_chunk
param = self.id_to_real_params[param_id]
@@ -360,15 +361,16 @@ def get_offsets(self, param_id: int) -> tuple:
begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
chunk_offset = begin_in_chunk
shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
param_size = end_in_chunk - begin_in_chunk
shard_size = end_in_chunk - begin_in_chunk
assert chunk_offset >= 0 and shard_offset >= 0

return chunk_offset, shard_offset, param_size
return chunk_offset, shard_offset, shard_size

def collect_states(self, param_id: int) -> dict:
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
"""
Args:
param_id (int): id of the parameter whose state is to be gathered at master rank.
only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank.
Returns:
collected_states(dict): the gathered optimizer state of the parameter with the given id
@@ -379,21 +381,44 @@ def collect_states(self, param_id: int) -> dict:

# Get param & chunk & process group.
param = self.id_to_real_params[param_id]
fake_param = self.id_to_fake_params.get(param_id, None)
chunk = self.chunk_manager.get_chunk(param)
process_group = chunk.torch_pg
rank = dist.get_rank(process_group)
master_rank = 0
state_names = self.optim.get_state_names()

collected_states = {}

# Fetch names of states through all_gather.
local_state_names = None
if fake_param is not None:
local_state_names = list(self.optim.state[fake_param].keys())
gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
dist.barrier()
dist.all_gather_object(gathered_state_names, local_state_names)
state_names = None
for names in gathered_state_names:
if names is not None:
# Assume different devices share the same set of state names if they have any.
state_names = copy.deepcopy(names)
break

# Directly return if this parameter doesn't have optimizer states.
# e.g. parameter frozen / layer dropped
if state_names is None:
return collected_states

# The boolean variable is_collector indicates whether the current rank
# needs to gather the whole optimizer state.
# Only master rank is collector when only_rank_0 is True.
# Every rank is collector when only_rank_0 is False.
is_collector = (rank == master_rank) or (not only_rank_0)

# If the chunk is kept gathered,
# the parameters are treated the same as those in strict DDP during training.
# So states can be directly fetched from current device.
if chunk.keep_gathered:
assert param_id in self.id_to_fake_params
fake_param = self.id_to_fake_params[param_id]
if rank == master_rank:
if is_collector:
states = self.optim.state[fake_param]
for state_name in state_names:
if state_name == 'step':
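Note on the hunk above: the per-parameter state names are now discovered collectively rather than via the removed get_state_names() helpers. Ranks that hold no shard of the parameter contribute None, and the first non-None list wins. A standalone sketch of that step, assuming an initialized process group; optimizer_state stands in for self.optim.state[fake_param]:

import copy
import torch.distributed as dist

def gather_state_names(optimizer_state, process_group=None):
    """Return the state names for a parameter, or None if no rank owns any."""
    local_names = list(optimizer_state.keys()) if optimizer_state is not None else None
    gathered = [None for _ in range(dist.get_world_size(process_group))]
    dist.all_gather_object(gathered, local_names, group=process_group)
    for names in gathered:
        if names is not None:
            # Owning ranks are assumed to share the same set of state names.
            return copy.deepcopy(names)
    return None  # e.g. a frozen parameter with no optimizer state on any rank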
@@ -408,42 +433,8 @@ def collect_states(self, param_id: int) -> dict:
# Check whether the param with given id is managed by current process.
own_param = param_id in self.id_to_fake_params

# Compute position information (offsets) of state shard.
# If current process doesn't control this param, position message should be [rank, -1, -1]
# else it should be [rank, start_offset, end_offset]
state_shard_range = torch.tensor([rank, -1, -1], dtype=torch.int, requires_grad=False).cuda()
if own_param:
_, shard_offset, param_size = self.get_offsets(param_id)
state_shard_range[1] = shard_offset
state_shard_range[2] = shard_offset + param_size

# Ranks other than master send position messages to master.
# The master rank should receive a dict mapping rank number to state shard offsets, in the form of:
# {rank_0: (0, x_1), rank_1: (x_1, x_2), rank2: (x_2, x_3), ... rank_{n-1}: (x_{n-1}, param.numels()) }
dist.barrier(process_group)
range_info = dict()
if rank == master_rank:

# Record if master owns this param.
if state_shard_range[-1] >= 0:
range_info[master_rank] = (state_shard_range[-2].item(), state_shard_range[-1].item())

container_tensor = torch.zeros_like(state_shard_range, dtype=torch.int, requires_grad=False).cuda()
for src in range(0, dist.get_world_size(process_group)):
if src == master_rank:
continue
dist.recv(container_tensor, src, process_group)
# Only keeps valid position messages.
if container_tensor[-1] >= 0:
range_info[container_tensor[0].item()] = (container_tensor[-2].item(), container_tensor[-1].item())
assert len(list(range_info.keys())) > 0

else:
dist.send(state_shard_range, master_rank, process_group)
dist.barrier(process_group)

# Master get prepared for state collecting.
if rank == master_rank:
# Collector gets prepared for state collecting.
if is_collector:
for state_name in state_names:
if state_name == 'step':
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
@@ -452,35 +443,33 @@ def collect_states(self, param_id: int) -> dict:
collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32,
requires_grad=False).cpu()

# Master starts to receive state tensors from other ranks
# after all the position messages of state shards has been collected.
compacted_states = self.pack_optimizer_states_to_tensor(param_id)
if rank == master_rank:
for src, shard_range in range_info.items():
# Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
_, shard_offset, shard_size = self.get_offsets(param_id)

# If src is master, directly load from compacted_states tensor.
if src == master_rank:
self.load_from_compacted_states(compacted_states, collected_states, shard_range)
continue
# Collectors gather state shards through all_gather_object.
gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]

# If src is not master, receive states from other ranks.
shard_size = shard_range[1] - shard_range[0]
num_states = len(state_names)
container_size = 1 + (num_states - 1) * shard_size if 'step' in state_names \
else num_states * shard_size
container_tensor = torch.zeros(container_size, dtype=torch.float32, requires_grad=False).cuda()
dist.recv(container_tensor, src, process_group)
self.load_from_compacted_states(container_tensor, collected_states, shard_range)
del container_tensor
else:
if own_param:
dist.send(compacted_states, master_rank, process_group)
dist.barrier()
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])

del compacted_states
dist.barrier(process_group)
if is_collector:
for state_shard in gathered_state_shards:
compacted_states = state_shard[0]
shard_offset = state_shard[1]
shard_size = state_shard[2]
if compacted_states is None:
continue
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)

# Clean gathered states
for state_shard in gathered_state_shards:
del state_shard[0]
gc.collect()

# Reshape tensors
if rank == master_rank:
if is_collector:
for state_name, state_tensor in collected_states.items():
if state_tensor.numel() == param.numel():
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
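In the hunk above, every rank contributes [compacted_states, shard_offset, shard_size] to a single all_gather_object call; a rank without a shard passes None for the tensor (and get_offsets returns its -1, -1, -1 sentinel), so collectors simply skip it. A simplified sketch of stitching gathered shards into one flat buffer, for a single tensor state and with illustrative names:

import torch
import torch.distributed as dist

def gather_full_state(local_shard, shard_offset, shard_size, numel, process_group=None):
    """All-gather (shard, offset, size) triples and stitch them into a flat tensor.

    local_shard is a 1-D CPU tensor of length shard_size, or None when this
    rank owns no part of the parameter.
    """
    world_size = dist.get_world_size(process_group)
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, [local_shard, shard_offset, shard_size],
                           group=process_group)

    full = torch.zeros(numel, dtype=torch.float32)
    for shard, offset, size in gathered:
        if shard is None:
            continue  # this rank held no shard of the parameter
        full[offset:offset + size] = shard.to(torch.float32)
    return full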
@@ -489,6 +478,7 @@ def pack_optimizer_states_to_tensor(self,

def pack_optimizer_states_to_tensor(self,
param_id: int,
state_names: list,
device: torch.device = torch.device('cuda'),
dtype: torch.dtype = torch.float32) -> torch.Tensor:
'''
Expand All @@ -502,7 +492,7 @@ def pack_optimizer_states_to_tensor(self,
states = self.optim.state[fake_param]
shard_size = param_range[1] - param_range[0]
compacted_size = 0
for name in self.optim.get_state_names():
for name in state_names:
if name == 'step':
compacted_size += 1
else:
@@ -526,16 +516,16 @@ def pack_optimizer_states_to_tensor(self,

return compacted_states

def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, shard_range: tuple):
def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list,
shard_start: int, shard_size: int):
'''
Given a tensor carrying compacted optimizer states,
update these states to collected_states.
'''
shard_start, shard_end = shard_range
shard_size = shard_end - shard_start
shard_end = shard_start + shard_size
next_state_offset = 0

for state_name in self.optim.get_state_names():
for state_name in state_names:
if state_name == 'step':
collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(),
dtype=torch.float32,
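The compacted layout used above is simply the states laid out back to back in state_names order: the scalar 'step' takes a single float slot and every other state is flattened. A minimal pack/unpack round trip under that assumption, with plain dicts, full parameter coverage, and no chunk or shard bookkeeping:

import torch

def pack_states(states, state_names):
    """Flatten optimizer states into one float32 tensor; 'step' uses one slot."""
    pieces = []
    for name in state_names:
        if name == 'step':
            pieces.append(torch.tensor([float(states[name])], dtype=torch.float32))
        else:
            pieces.append(states[name].flatten().to(torch.float32))
    return torch.cat(pieces)

def unpack_states(compacted, state_names, numel):
    """Inverse of pack_states for a shard that covers the whole parameter."""
    out, offset = {}, 0
    for name in state_names:
        if name == 'step':
            out[name] = compacted[offset].item()
            offset += 1
        else:
            out[name] = compacted[offset:offset + numel].clone()
            offset += numel
    return out

# Round trip on made-up states for a 6-element parameter.
states = {'step': 3, 'exp_avg': torch.randn(2, 3), 'exp_avg_sq': torch.rand(2, 3)}
packed = pack_states(states, ['step', 'exp_avg', 'exp_avg_sq'])
restored = unpack_states(packed, ['step', 'exp_avg', 'exp_avg_sq'], numel=6)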
@@ -564,10 +554,6 @@ def state_dict(self, only_rank_0: bool = True) -> dict:
Warning: This method will gather and return the whole optimizer state_dict,
so it should be called only when memory resources are abundant.
"""

if not only_rank_0:
raise ValueError("False 'only_rank_0' in self.state_dict() method is currently not supported")

state_dict = {}
state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)

Expand All @@ -589,7 +575,7 @@ def state_dict(self, only_rank_0: bool = True) -> dict:
state_dict['state'] = dict()
for param_id in self.id_to_real_params.keys():
dist.barrier()
state_dict['state'][param_id] = self.collect_states(param_id=param_id)
state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
return state_dict

def load_param_groups(self, saved_param_groups: list):
5 changes: 3 additions & 2 deletions tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -58,7 +58,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin(placement_policy=placement_policy)
plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
booster = Booster(plugin=plugin)

model = model_fn()
@@ -91,7 +91,8 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha
new_model.unwrap().state_dict(only_rank_0=False), False)

booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
new_optimizer.unwrap().state_dict(only_rank_0=False), False)

# Check the new model/optimizer can successfully run.
data = data_gen_fn()
25 changes: 13 additions & 12 deletions tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -25,7 +25,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):

(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin()
plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14))
booster = Booster(plugin=plugin)

model = model_fn()
@@ -65,12 +65,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
new_model.state_dict(), False)

new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
old_state_dict = optimizer.state_dict()
new_state_dict = new_optimizer.state_dict()

# The complete optimizer state_dict of GeminiPlugin is only collected on rank 0
if dist.get_rank() == 0:
check_state_dict_equal(old_state_dict, new_state_dict, False)
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)

# Check the new model/optimizer can successfully run.
data = data_gen_fn()
@@ -134,12 +129,18 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
model.state_dict(), False)

new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
new_state_dict = new_optimizer.state_dict()
old_state_dict = optimizer.state_dict()

# The complete optimizer state_dict of GeminiPlugin is only collected on rank 0
if dist.get_rank() == 0:
check_state_dict_equal(new_state_dict, old_state_dict, False)
new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)

# Comparison of param_groups needs special care here,
# since not all hyperparameters in Adam are used by HybridAdam
hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay']
for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
for k in hyperparameters_to_examine:
assert k in old_group and k in new_group, \
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
assert old_group[k] == new_group[k]
check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)

# Check the new model/optimizer can successfully run.
data = data_gen_fn()
