[hotfix] fix unsafe async comm in zero (#4404)
* improve stability of zero

* fix wrong index

* add record stream
Gy-Lu authored Aug 11, 2023
1 parent 6ccecc0 commit d86ddd9
Showing 3 changed files with 46 additions and 20 deletions.
55 changes: 36 additions & 19 deletions colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -13,15 +13,20 @@ class BucketStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)

-        # init and reset
+        # init
         self.current_group_id = 0
         self._num_elements_in_bucket = 0
+        # mapping gradient slices and parameters
+        self.grad_to_param_mapping = dict()

         self._grad_in_bucket = dict()
         self._param_list = []
         self._padding_size = []
         for rank in range(self._world_size):
             self._grad_in_bucket[rank] = []

-        self.reset()
+        # offset_list records the number of tensors in the bucket before each reduction
+        self.offset_list = [0]

     def num_elements_in_bucket(self) -> int:
         """Return the total number of elements in bucket
@@ -32,6 +37,12 @@ def num_elements_in_bucket(self) -> int:

         return self._num_elements_in_bucket

+    def reset_num_elements_in_bucket(self):
+        """Set the number of elements in bucket to zero.
+        """
+
+        self._num_elements_in_bucket = 0
+
     def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
         """Add a param to bucket and record the padding size of a param for gradient padding
@@ -46,28 +57,32 @@ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
         self._num_elements_in_bucket += (param.numel() + padding_size)
         self.current_group_id = group_id

+        # number of tensors in current bucket
+        self.offset_list[-1] += 1
+
     def build_grad_in_bucket(self):
         """Organize parameters' gradients (padding and splitting), following the parameters' splitting method
         Data structure of self._grad_in_bucket:
         {
             rank0: [grad0_rank0, grad1_rank0, ...]
-            rank1: [grad1_rank1, grad1_rank1, ...]
+            rank1: [grad0_rank1, grad1_rank1, ...]
         }
         """

         for param, padding_size in zip(self._param_list, self._padding_size):
-            with torch.no_grad():
-                grad = param.grad.detach().flatten()
-                if padding_size > 0:
-                    grad = torch.nn.functional.pad(grad, [0, padding_size])
-                grad_list = grad.split(grad.numel() // self._world_size)
-                for rank in range(self._world_size):
-                    grad_current_rank = grad_list[rank].detach()
-                    self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
-                    self._grad_in_bucket[rank].append(grad_current_rank)
+            grad = param.grad.clone().detach().flatten()
+            if padding_size > 0:
+                with torch.no_grad():
+                    grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
+            grad_list = grad.split(grad.numel() // self._world_size)
+            for rank in range(self._world_size):
+                grad_current_rank = grad_list[rank].clone().detach()
+                self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
+                self._grad_in_bucket[rank].append(grad_current_rank)
+            param.grad = None
+
+        self.offset_list.append(0)

     def get_grad(self) -> Dict:
         """Return the dictionary of gradient slices, of which the keys are ranks
@@ -104,10 +119,12 @@ def get_param_id_of_grad(self, grad: Tensor) -> int:
         return self.grad_to_param_mapping[id(grad)]

     def reset(self):
-        self.grad_to_param_mapping = dict()
-        self._num_elements_in_bucket = 0
-        self._param_list = []
-        self._padding_size = []
-        self._grad_in_bucket = dict()
+        """Reset the bucket storage after reduction; only release the tensors that have been reduced
+        """
+        cur_offset = self.offset_list.pop(0)
+        self._param_list = self._param_list[cur_offset:]
+        self._padding_size = self._padding_size[cur_offset:]
+        for _ in range(cur_offset):
+            del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))]
         for rank in range(self._world_size):
-            self._grad_in_bucket[rank] = []
+            self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:]
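Note: the bucket bookkeeping above can be summarized outside of ColossalAI. The sketch below is illustrative only (world_size, grads and grad_in_bucket are made-up names, not the commit's code): each gradient is flattened, padded so it splits evenly across ranks, and sliced per rank, while offset_list counts how many tensors entered the bucket before each reduction so that reset() releases only the slices that have already been reduced.

# Minimal, self-contained sketch of the bucketing scheme (not ColossalAI's API).
import torch

world_size = 4
grads = [torch.randn(10), torch.randn(7)]       # toy gradients
grad_in_bucket = {rank: [] for rank in range(world_size)}
offset_list = [0]

for grad in grads:
    flat = grad.clone().detach().flatten()
    padding = (-flat.numel()) % world_size      # pad so the split is even
    if padding > 0:
        flat = torch.nn.functional.pad(flat, [0, padding])
    for rank, piece in enumerate(flat.split(flat.numel() // world_size)):
        grad_in_bucket[rank].append(piece)      # one slice per rank
    offset_list[-1] += 1                        # count tensors in the open bucket

offset_list.append(0)                           # mark a reduction boundary

# after the reduction finishes, release only the slices that were reduced
cur_offset = offset_list.pop(0)
for rank in range(world_size):
    grad_in_bucket[rank] = grad_in_bucket[rank][cur_offset:]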
9 changes: 9 additions & 0 deletions colossalai/zero/low_level/low_level_optim.py
@@ -242,10 +242,19 @@ def _attach_reduction_hook(self):
     def _run_reduction(self):
         if self._bucket_store.num_elements_in_bucket() > 0:
             self._bucket_store.build_grad_in_bucket()
+
             flat_grads = self._bucket_store.get_flatten_grad()
             flat_grads /= self._world_size
+
+            # ready to add other tensors to the bucket
+            self._bucket_store.reset_num_elements_in_bucket()
+
             if self._overlap_communication:
                 stream = self._comm_stream
+                # keep flat_grads' memory from being reused by the default stream too early
+                flat_grads.record_stream(stream)
+                # wait for ops in the default stream to finish
+                stream.wait_stream(torch.cuda.current_stream())
             else:
                 stream = torch.cuda.current_stream()

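Note: this hunk is the heart of the hotfix. flat_grads is produced on the default stream but consumed by the communication stream, so without record_stream the caching allocator could hand its memory to a later allocation while the asynchronous reduction is still reading it; wait_stream orders the communication after the work that produced flat_grads. Below is a minimal sketch of that pattern, assuming a CUDA device and an already-initialized torch.distributed process group (the all_reduce call is illustrative, not the optimizer's exact collective):

# Sketch of safe asynchronous communication on a side CUDA stream.
import torch
import torch.distributed as dist

def reduce_on_side_stream(flat_grads: torch.Tensor, comm_stream: torch.cuda.Stream):
    # tell the caching allocator that flat_grads is also used on comm_stream,
    # so its memory is not recycled by the default stream before the reduce finishes
    flat_grads.record_stream(comm_stream)
    # make the comm stream wait until the default stream has produced flat_grads
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(flat_grads)   # runs asynchronously w.r.t. the default stream

In the optimizer itself, the same pattern wraps the reduction of flat_grads on self._comm_stream whenever overlap_communication is enabled.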
2 changes: 1 addition & 1 deletion tests/test_zero/test_low_level/test_zero1_2.py
@@ -137,7 +137,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=True,
                                            initial_scale=1,
-                                           reduce_bucket_size=262144)
+                                           reduce_bucket_size=1024 * 1024)

     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

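Note: the test now uses a larger bucket (1024 * 1024 elements instead of 262144), so several parameters share one asynchronous reduction and the new partial-reset path gets exercised. A sketch of how such an optimizer is constructed, mirroring the call above; the import path and the toy model are assumptions, and a distributed environment (e.g. via colossalai.launch) must already be initialized:

# Illustrative setup only; requires an initialized distributed environment.
import torch
from colossalai.zero import LowLevelZeroOptimizer   # assumed import path

model = torch.nn.Linear(128, 128).cuda()
base_optimizer = torch.optim.SGD(model.parameters(), lr=1)
zero_optimizer = LowLevelZeroOptimizer(base_optimizer,
                                       overlap_communication=True,
                                       initial_scale=1,
                                       reduce_bucket_size=1024 * 1024)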