diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 5cf655d8741a..142259c1b7df 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1635,19 +1635,16 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False): secondary_end = secondary_start + secondary_partition_size one_dim_param = param.contiguous().view(-1) - start = partition_size * self.rank - end = start + partition_size - if start < param.ds_numel and end <= param.ds_numel: - if secondary_start < param.ds_numel and secondary_end <= param.ds_numel: - sec_src_tensor = one_dim_param.narrow(0, secondary_start, secondary_partition_size) - param.ds_secondary_tensor.copy_(sec_src_tensor) - else: - if start < param.ds_numel: - elements_to_copy = param.ds_numel - start - elements_to_copy_sec = elements_to_copy * param.ds_secondary_tensor_num_of_groups - param.ds_secondary_tensor.narrow(0, 0, elements_to_copy_sec).copy_( - one_dim_param.narrow(0, secondary_start, elements_to_copy_sec)) + # ds_numel is unpadded, so the last chunk of the secondary tensor might not be secondary_partition_size + sec_numel = param.ds_numel - secondary_start if secondary_end > param.ds_numel else secondary_partition_size + + # copy from full tensor to secondary tensor + param.ds_secondary_tensor.narrow(0, 0, + sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel)) + + # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done + get_accelerator().current_stream().synchronize() print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}", force=False)