
[BUG] Order of partitioning in universal checkpoint #5283

Closed · rgtjf opened this issue Mar 15, 2024 · 14 comments · Fixed by #5342

Labels: bug (Something isn't working), training


rgtjf commented Mar 15, 2024

Describe the bug

Background:
Automatic adjustment of ZeRO's optimizer state partitioning for a new world size is not currently supported.

Issues:
I used the code from PR #4516 to convert a DeepSpeed checkpoint without TP and PP (128 ranks) into another DeepSpeed checkpoint (32 ranks). In my testing (ZeRO-2), I found that merging the fragments in the sorted order does not produce correct results.

import glob
import os

import torch


def check_mp_equal_to_fp32(args):
    """Compare each merged fp32 tensor against the corresponding entry in the
    mp_rank_00 model state dict."""
    output_folder = "./output"

    mp_sd = torch.load(
        os.path.join(output_folder, "output", "mp_rank_00_model_states.pt"),
        map_location=torch.device("cpu"),
    )
    zero_output_folder = os.path.join(output_folder, "zero")
    tensor_name_paths = sorted(glob.glob(f"{zero_output_folder}/*"))
    for tensor_name_path in tensor_name_paths:
        if "model" not in tensor_name_path:
            continue
        tensor_name = os.path.basename(tensor_name_path)
        fp32 = torch.load(os.path.join(tensor_name_path, "fp32.pt"))["param"].to(
            mp_sd["module"][tensor_name]
        )
        torch.testing.assert_allclose(
            fp32,
            mp_sd["module"][tensor_name],
            msg=f"{tensor_name}, fp32: \n{fp32}, mp_sd: \n{mp_sd['module'][tensor_name]}",
        )

In this example, I found that the fragments were merged in neither alphabetical nor numerical order.

Expected behavior
We need to build the linked matrix for the partitioning so that we can recover the correct merge order. Maybe something like this:

import argparse
import copy
import glob
import os
from collections import defaultdict

import torch


def _build_linked_matrix(args: argparse.Namespace, temp_dir: str, mp_sd: dict):
    linked_matrix = defaultdict(list)
    reversed_linked_matrix = defaultdict(list)

    for tensor_name, mp_tensor in mp_sd["module"].items():
        fp32_paths = sorted(glob.glob(f"{temp_dir}/{tensor_name}/0/fp32.*"))
        fp32_tensors = []
        for fp32_path in fp32_paths:
            # The partition index is the file suffix, e.g. "fp32.018" -> "018".
            partition_idx = fp32_path.split("/")[-1].split(".")[-1]
            fp32_tensor = torch.load(fp32_path)
            fp32_tensors.append((partition_idx, fp32_tensor))

        flat_tensor = mp_tensor.flatten()
        visited = [False for _ in range(len(fp32_tensors))]
        offset = 0
        for merge_cnt in range(len(fp32_tensors)):
            for index, (partition_idx, fp32_tensor) in enumerate(fp32_tensors):
                numel = fp32_tensor.numel()
                if not visited[index]:
                    # Try placing this fragment at the current offset; if the
                    # result still matches the reference tensor, the fragment
                    # belongs here.
                    cloned_flat_tensor = copy.deepcopy(flat_tensor)
                    cloned_flat_tensor[offset : offset + numel] = fp32_tensor.to(
                        cloned_flat_tensor
                    )
                    if torch.allclose(flat_tensor, cloned_flat_tensor):
                        visited[index] = True
                        # tensor_name -> [(partition_idx, numel)] in temp
                        linked_matrix[tensor_name].append((partition_idx, numel))
                        # fp32.idx -> [(tensor_name, offset, numel)]
                        # It is ordered because we merge one partition at a time.
                        reversed_linked_matrix[partition_idx].append(
                            (tensor_name, offset, numel)
                        )
                        offset += numel
        print(tensor_name, linked_matrix[tensor_name], mp_tensor.shape)

    # linked_matrix["model.transformer_encoder.layers.31.linear1.weight"]
    # Out:
    # [('018', 54715476), ('019', 12393388)]

    # reversed_linked_matrix['019']
    # Out:
    # [('model.transformer_encoder.layers.31.linear1.weight', 54715476, 12393388),
    #  ('model.transformer_encoder.layers.31.linear1.bias', 0, 16384),
    #  ('model.transformer_encoder.layers.31.linear2.weight', 0, 46882008)]
    # => 54715476 is the offset of the 12393388 elements in layers.31.linear1.weight
    print(linked_matrix, reversed_linked_matrix)
    torch.save(
        (linked_matrix, reversed_linked_matrix),
        os.path.join(args.input_folder, "linked_matrix.pt"),
    )
    return linked_matrix, reversed_linked_matrix
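
For reference, a hypothetical invocation might look like the following. The folder layout, the argparse namespace, and the mp_rank_00 path are assumptions for illustration, not taken from the actual ds_to_universal output layout.

# Illustrative usage sketch only: paths and argument names below are assumptions.
import argparse
import os

import torch

args = argparse.Namespace(input_folder="./checkpoint")      # hypothetical folder
temp_dir = os.path.join(args.input_folder, "temp")           # per-parameter fp32.* shards
mp_sd = torch.load(
    os.path.join("./output", "mp_rank_00_model_states.pt"),  # merged model states
    map_location=torch.device("cpu"),
)

linked_matrix, reversed_linked_matrix = _build_linked_matrix(args, temp_dir, mp_sd)
# linked_matrix maps a parameter name to its (partition_idx, numel) fragments in
# flat-buffer order; reversed_linked_matrix maps a partition index back to the
# (tensor_name, offset, numel) pieces it holds.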

System info (please complete the following information):

  • OS: Ubuntu 22.04
  • GPU count and types: A100
  • Python version: 3.10

tjruwase commented Apr 1, 2024

@rgtjf, apologies for the delay on this. Can you please check PR #5342?


rgtjf commented Apr 2, 2024

@tjruwase Thanks again for your team's work! After trying this, I found that in rare cases the numerical order is also incorrect. We may need to construct a linked matrix to find the sharding order.

@tjruwase tjruwase assigned tohtana and unassigned samadejacobs Apr 2, 2024

tjruwase commented Apr 2, 2024

@rgtjf, thanks for testing and sharing feedback so quickly. Are you able to provide more details on the failing corner cases?

@tohtana, FYI. Can you please check the linked matrix solution proposed by @rgtjf, to see if that will work?

github-merge-queue bot pushed a commit that referenced this issue Apr 2, 2024
The conversion from a regular checkpoint to a universal one relies on sorting the ZeRO checkpoint files in order to merge the sharded optimizer states. This merge can silently produce wrong results because the sorting is alphabetical.

The merging logic assumes that the files are given in this order:
1. pp_index=0 tp_index=0 dp_index=0
2. pp_index=0 tp_index=0 dp_index=1
...

The optimizer state of a parameter can be sharded across multiple ranks. If it is sharded across dp_index 9-11, the files will be:
- bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt
- bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt
- bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt

Because the files are sorted alphabetically, the script merges the sharded fragments in the order [10, 11, 9]. This PR fixes the sort by extracting the DP ranks from the file names and sorting the files with the ranks treated as numbers.

Fix #5283

Co-authored-by: Olatunji Ruwase <[email protected]>
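
To illustrate the fix described in the commit message, here is a minimal sketch of sorting checkpoint files by their DP rank treated as a number; the regex and the hard-coded file list are illustrative and not the actual code from the PR.

import re

# Illustrative file names; real ones come from the ZeRO checkpoint directory.
files = [
    "bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt",
    "bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt",
    "bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt",
]

def dp_rank(path):
    # Extract the number following "pp_rank_" and treat it as an integer.
    match = re.search(r"pp_rank_(\d+)", path)
    assert match is not None, f"unexpected file name: {path}"
    return int(match.group(1))

print(sorted(files))               # alphabetical: ranks appear as 10, 11, 9
print(sorted(files, key=dp_rank))  # numeric: ranks appear as 9, 10, 11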

rgtjf commented Apr 4, 2024

@tohtana I uploaded the linked matrix in JSON format, and we can see that "model.transformer_encoder.layers.26.linear1.weight": [["126", 53654416], ["001", 13454448]], which means that the first fragment of this parameter's fp32 optimizer state is stored on rank 126 and the second fragment on rank 1.

linked_matrix.json
reversed_linked_matrix.json


rgtjf commented Apr 9, 2024

@tohtana Looking forward to your reply. I ran a 7B model using ZeRO-2 on 128 A100s. Let me know if you need anything else.

@tohtana tohtana reopened this Apr 9, 2024

tohtana commented Apr 9, 2024

Hi @rgtjf, sorry for my late reply.

Can you elaborate on the case where #5342 fails to sort the sharded parameters?
I don't think #5342 is a robust solution. I would be happy to replace it if we have a better approach.


rgtjf commented Apr 9, 2024

@tohtana Thanks for your effort!

During the ds_to_universal conversion (https://github.com/tohtana/DeepSpeed/blob/master/deepspeed/checkpoint/ds_to_universal.py#L141), the parameter fragments are saved per parameter name. In the next step, the fragments are merged, and at this point there is a problem with the order (https://github.com/tohtana/DeepSpeed/blob/master/deepspeed/checkpoint/ds_to_universal.py#L155).

Here I find the order is neither alphabetical nor numerical, which is why the linked_matrix approach above is proposed. (Discussion here: #4516 (comment))

For example, model.transformer_encoder.layers.26.linear1.weight (the MLP up_proj linear layer) has 67108864 = 4096 * 16384 elements in total. From my observation, its first 53654416 elements are placed on rank 126 and the remaining 13454448 elements on rank 1, which differs from the numerical order.
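
As a worked illustration of the ordering described above (not code from the report), the sketch below uses a scaled-down synthetic tensor: concatenating the fragments in the observed linked-matrix order (rank 126 first, then rank 1) reconstructs the flat parameter, whereas numeric rank order does not.

import torch

# Scaled-down stand-in for model.transformer_encoder.layers.26.linear1.weight,
# whose 67108864 elements were observed split as 53654416 elements on rank 126
# followed by 13454448 elements on rank 1.
full = torch.randn(64)
frag_rank126 = full[:52].clone()  # plays the role of the rank-126 fragment
frag_rank001 = full[52:].clone()  # plays the role of the rank-1 fragment

# Merging in the observed (linked-matrix) order reconstructs the parameter ...
assert torch.equal(torch.cat([frag_rank126, frag_rank001]), full)
# ... whereas merging in numeric rank order (rank 1 before rank 126) does not.
assert not torch.equal(torch.cat([frag_rank001, frag_rank126]), full)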


tohtana commented Apr 10, 2024

Thank you for the clarification. I think I figured out the issue. Let me consider how we can fix it.


tohtana commented Apr 10, 2024

@rgtjf I opened a PR to resolve this issue.

This is more ad hoc than what you proposed, but I think it can solve the sorting issue. Do you have any thoughts?


rgtjf commented Apr 12, 2024

@tjruwase the PR sorts the tensor partitions numerically before merging them into a single optimizer state. However, we also need to ensure that the merged optimizer state (fp32_tensors) aligns consistently with the model parameters, which in my case it currently does not. Unfortunately, it's challenging to create demo code for this issue. Could you share your thoughts on this?

Here are the relevant files:


tohtana commented Apr 15, 2024

@rgtjf When creating a universal checkpoint, we take both model parameters and optimizer states from the merged FP32 data in the ZeRO checkpoints. I'm not sure I understand the issue you mentioned, but I think it is the sorting issue discussed here.


rgtjf commented Apr 15, 2024

@tohtana I've seen some inconsistencies with my previous understanding; would you mind taking a look?

In [4]: ds_checkpoint.get_zero_files(0, 0, 1)
Out[4]: ['./checkpoint/bf16_zero_pp_rank_100_mp_rank_00_optim_states.pt']

In [5]: ds_checkpoint.get_zero_files(0, 0, 127)
Out[5]: ['./checkpoint/bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt']

In [14]: ds_checkpoint.get_zero_checkpoint_state(0, 0, 1)
Out[14]: 
{'optimizer_state_dict': {'loss_scaler': <deepspeed.runtime.fp16.loss_scaler.LossScaler at 0x7f391db53d90>,
  'dynamic_loss_scale': False,
  'overflow': False,
  'clip_grad': 1.0,
  'base_optimizer_state': {'state': {0: {'step': tensor(900001.),
     'exp_avg': tensor([ 2.8788e-08, -6.7947e-08, -2.8542e-08,  ...,  5.5384e-08,
             -3.2571e-09, -4.0780e-08]),
     'exp_avg_sq': tensor([1.6767e-13, 1.6810e-13, 1.6624e-13,  ..., 1.1828e-13, 1.0733e-13,
             1.1851e-13])}},
   'param_groups': [{'lr': 6.96658312522423e-05,
     'betas': [0.9, 0.95],
     'eps': 1e-08,
     'weight_decay': 0.1,
     'amsgrad': False,
     'foreach': None,
     'maximize': False,
     'capturable': False,
     'differentiable': False,
     'fused': None,
     'initial_lr': 0.0003,
     'params': [0]}]},
  'single_partition_of_fp32_groups': [tensor([ 0.0173, -0.1003, -0.0599,  ...,  0.0796,  0.0342,  0.0743],
          requires_grad=True)],
  'zero_stage': <ZeroStageEnum.gradients: 2>,
  'group_paddings': [0],
  'partition_count': [128],
  'ds_version': '0.12.6',
  'param_slice_mappings': [OrderedDict([('model.transformer_encoder.layers.26.linear1.weight',
                 fragment_address(numel=13454448, start=0)),
                ('model.transformer_encoder.layers.26.linear1.bias',
                 fragment_address(numel=16384, start=13454448)),
                ('model.transformer_encoder.layers.26.linear2.weight',
                 fragment_address(numel=45820948, start=13470832))])]},
 'ds_config': {'bf16': {'enabled': True},
  'train_micro_batch_size_per_gpu': 4,
  'zero_allow_untested_optimizer': True,
  'zero_optimization': {'allgather_bucket_size': 500000000.0,
   'contiguous_gradients': True,
   'overlap_comm': False,
   'reduce_bucket_size': 500000000.0,
   'reduce_scatter': True,
   'stage': 2},
  'gradient_accumulation_steps': 1,
  'gradient_clipping': 1.0},
 'ds_version': '0.12.6'}

Maybe it is not easy to get the file path from the DP index.

Shall we use a method similar to the one here? https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py#L77C1-L84C22
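
For reference, the helper referenced there is a natural-sort key; a sketch of the idea (not a verbatim copy of zero_to_fp32.py) looks roughly like this. Unlike the dp_rank extraction sketched earlier, it is a generic alphanumeric key.

import re

def atoi(text):
    # Turn digit runs into ints so that "9" sorts before "10".
    return int(text) if text.isdigit() else text

def natural_keys(text):
    # Split "bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt" into
    # ["bf", 16, "_zero_pp_rank_", 10, "_mp_rank_", 0, "_optim_states.pt"].
    return [atoi(chunk) for chunk in re.split(r"(\d+)", text)]

files = [
    "bf16_zero_pp_rank_100_mp_rank_00_optim_states.pt",
    "bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt",
    "bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt",
]
print(sorted(files, key=natural_keys))  # ranks 9, 10, 100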


rgtjf commented Apr 19, 2024

I apologize for the confusion. The main issue stemmed from non-corresponding partitions, hence the proposal for the linked-matrix method.

If the change in DeepSpeed #5342 to sort the files in get_zero_files() is in place, I believe that resolves the issue. Once the new PR DeepSpeed #5395 is merged, I think we can close the ticket.

Thank you for your efforts! @tohtana @tjruwase

github-merge-queue bot pushed a commit that referenced this issue Apr 19, 2024
…5395)

This PR resolves the issue reported in #5283.
To resolve the issue, we sort files of sharded optimizer states based on
DP indices.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>

tohtana commented Apr 19, 2024

Hi @rgtjf, thank you for the information.
#5396 has now been merged. I'm closing this issue, but please feel free to reopen it if you still see the problem.

@tohtana tohtana closed this as completed Apr 19, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this issue May 9, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this issue May 9, 2024
umchand pushed a commit to umchand/DeepSpeed that referenced this issue May 20, 2024
dbyoung18 pushed a commit to dbyoung18/DeepSpeed that referenced this issue Jun 11, 2024
dbyoung18 pushed a commit to dbyoung18/DeepSpeed that referenced this issue Jun 11, 2024