# [zero++] Synchronize at the end of secondary partitioning and simplify the logic (microsoft#5216)

## 1. Why?

We have a very long thread investigating [the issue](microsoft#5059). To summarize, the problem arises because:

a. The 2nd partitioning is asynchronous, since it performs a device-to-device copy from the full tensor to the 2nd tensor.
b. When prefetching is enabled, the all-gather of the 2nd tensor can start before the 2nd partitioning has finished, so at that moment the 2nd tensor may contain bad values (a minimal sketch of this race is given below the figure).


![image](https://github.com/microsoft/DeepSpeed/assets/24364830/ad6ee6a2-8e1e-4214-a0d2-ee5314b252b8)
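
For illustration only, here is a minimal, self-contained sketch of the race described above, written in plain PyTorch rather than DeepSpeed code. The tensor names, sizes, and `prefetch_stream` are made up; `prefetch_stream` stands in for the stream that runs the prefetched all-gather of the 2nd tensor.

```
# Sketch of the race: an asynchronous device-to-device copy can still be in flight
# when a different CUDA stream reads the destination tensor.
import torch

if torch.cuda.is_available():
    full_tensor = torch.randn(1 << 24, device="cuda")  # stands in for the gathered full parameter
    secondary = torch.empty(1 << 22, device="cuda")    # stands in for param.ds_secondary_tensor
    prefetch_stream = torch.cuda.Stream()               # stands in for the prefetch/all-gather stream

    # (a) 2nd partitioning: an asynchronous device-to-device copy on the current stream.
    secondary.copy_(full_tensor.narrow(0, 0, secondary.numel()))

    # Without this host-side synchronization, the read below can overlap the copy
    # and observe a partially written secondary tensor. This line is the moral
    # equivalent of the fix in this PR.
    torch.cuda.current_stream().synchronize()

    # (b) prefetch: a read of the 2nd tensor issued on a different stream.
    with torch.cuda.stream(prefetch_stream):
        gathered = secondary.clone()

    prefetch_stream.synchronize()
```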

We also found that the copy logic was incorrect and unnecessarily long, so we simplified it to just two lines (see the diff below).

Kudos to @yundai424, Haowen Ning, @samadejacobs for the investigation
effort.

## 2. What? 

After multiple careful tests, we found that patching in a `get_accelerator().synchronize()` call, which ensures all CUDA streams have finished at the end of the 2nd partitioning, prevents the issue (a sketch of where the call sits is shown below).
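
A minimal sketch of the patched spot, assuming the surrounding `_partition_param_sec` context from the diff at the bottom of this page; the merged change uses the current-stream variant of the call:

```
from deepspeed.accelerator import get_accelerator

# ... inside _partition_param_sec, right after copying into param.ds_secondary_tensor ...
# Block the host until the in-flight device-to-device copy has finished, so a
# prefetched all-gather issued afterwards cannot read a half-written secondary tensor.
get_accelerator().current_stream().synchronize()
```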

## 3. Tests

I validated the correctness of the simplified 2nd-partition logic: with the same random seed, the loss is exactly the same before and after the simplification.

Before

```
[
  {"loss": 2.0731},
  {"loss": 2.0288},
  {"loss": 1.927},
  {"loss": 1.8347},
  {"loss": 1.8347},
  {"loss": 1.7896},
  {"loss": 1.602},
  {"loss": 1.766},
  {"loss": 1.8751},
  {"loss": 1.6776}
]

```

After

```
[
  {"loss": 2.0731},
  {"loss": 2.0288},
  {"loss": 1.927},
  {"loss": 1.8347},
  {"loss": 1.8347},
  {"loss": 1.7896},
  {"loss": 1.602},
  {"loss": 1.766},
  {"loss": 1.8751},
  {"loss": 1.6776}
]


```

## 4. TODO

We need further investigation of the issue (@samadejacobs):
1) Revisit the ZeRO-3 prefetch design.
2) Refactor hpz to reuse the primary tensor for the secondary partition.

---------

Signed-off-by: byhsu <[email protected]>
Co-authored-by: byhsu <[email protected]>
2 people authored and rraminen committed May 9, 2024
1 parent 51d1022 commit fbe7eec
Showing 1 changed file: deepspeed/runtime/zero/partition_parameters.py (9 additions, 12 deletions)
```diff
@@ -1635,19 +1635,16 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
             secondary_end = secondary_start + secondary_partition_size
 
             one_dim_param = param.contiguous().view(-1)
-            start = partition_size * self.rank
-            end = start + partition_size
-            if start < param.ds_numel and end <= param.ds_numel:
-                if secondary_start < param.ds_numel and secondary_end <= param.ds_numel:
-                    sec_src_tensor = one_dim_param.narrow(0, secondary_start, secondary_partition_size)
-                    param.ds_secondary_tensor.copy_(sec_src_tensor)
-
-            else:
-                if start < param.ds_numel:
-                    elements_to_copy = param.ds_numel - start
-                    elements_to_copy_sec = elements_to_copy * param.ds_secondary_tensor_num_of_groups
-                    param.ds_secondary_tensor.narrow(0, 0, elements_to_copy_sec).copy_(
-                        one_dim_param.narrow(0, secondary_start, elements_to_copy_sec))
+            # ds_numel is unpadded, so the last chunk of the secondary tensor might not be secondary_partition_size
+            sec_numel = param.ds_numel - secondary_start if secondary_end > param.ds_numel else secondary_partition_size
+
+            # copy from full tensor to secondary tensor
+            param.ds_secondary_tensor.narrow(0, 0,
+                                             sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))
+
+            # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done
+            get_accelerator().current_stream().synchronize()
 
             print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}",
                          force=False)
```
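
To make the simplified last-chunk handling above concrete, here is a toy walk-through with made-up numbers: an unpadded `ds_numel` of 10 split across three secondary ranks with a `secondary_partition_size` of 4 (aligned size 12).

```
# Hypothetical sizes, only to illustrate the sec_numel computation in the diff above.
ds_numel = 10                 # unpadded parameter size
secondary_partition_size = 4  # aligned size 12 split across 3 ranks in the group

for rank_in_group in range(3):
    secondary_start = secondary_partition_size * rank_in_group
    secondary_end = secondary_start + secondary_partition_size
    # the last chunk might not be a full secondary_partition_size
    sec_numel = ds_numel - secondary_start if secondary_end > ds_numel else secondary_partition_size
    print(rank_in_group, secondary_start, sec_numel)

# prints: 0 0 4 / 1 4 4 / 2 8 2 -> the last rank copies only the 2 real (unpadded) elements
```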
