[BUG] Order of partitioning in universal checkpoint #5283
Comments
@tjruwase Thanks again for your team's work! After I tried this, I found that in rare cases the numerical order is also incorrect. We may need to construct a linked matrix to determine the sharding order.
The conversion from a regular checkpoint to a universal one relies on sorting the ZeRO checkpoint files to merge sharded optimizer states. This merge can silently produce wrong results because the sorting is alphabetical. The merging logic assumes that files are given in this order:
1. pp_index=0 tp_index=0 dp_index=0
2. pp_index=0 tp_index=0 dp_index=1
...
The optimizer state of a parameter can be sharded across multiple ranks. If it is sharded across dp_index 9-11, the files will be:
- bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt
- bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt
- bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt
Because these are sorted alphabetically, the script merges the sharded fragments in the order [10, 11, 9]. This PR fixes the sort by extracting the DP rank from each file name and sorting the files by that rank as a number. Fix #5283 Co-authored-by: Olatunji Ruwase <[email protected]>
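The misordering described above can be reproduced with a few lines of Python. This is a minimal sketch, not the actual DeepSpeed code: the `dp_rank` helper and the regex are illustrative assumptions, showing how extracting the rank as an integer restores the intended order.

```python
# Demonstrate the bug: alphabetical sorting of checkpoint file names
# misorders DP ranks once they reach two digits.
import re

files = [
    "bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt",
    "bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt",
    "bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt",
]

# Plain alphabetical sort compares character by character, so "10" and
# "11" sort before "9".
alphabetical = sorted(files)  # ranks come out as [10, 11, 9]

def dp_rank(name):
    # Illustrative helper: pull the DP rank out of the file name and
    # compare it as an integer rather than as text.
    return int(re.search(r"zero_pp_rank_(\d+)", name).group(1))

numerical = sorted(files, key=dp_rank)  # ranks come out as [9, 10, 11]
```

The fix in the PR follows the same idea: derive a numeric sort key from each file name instead of relying on lexicographic order.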
@tohtana I uploaded the linked matrix in JSON format, and we can see that
@tohtana Looking forward to your reply. I ran a 7B model using ZeRO2 on 128 A100s. Let me know if you need anything else.
@tohtana Thanks for your effort! During the ds_to_universal conversion (https://github.com/tohtana/DeepSpeed/blob/master/deepspeed/checkpoint/ds_to_universal.py#L141), the parameter fragments are saved under their parameter names. In the next step, the parameter fragments are merged, and at this point the order is wrong (https://github.com/tohtana/DeepSpeed/blob/master/deepspeed/checkpoint/ds_to_universal.py#L155). The order here is neither alphabetical nor numerical, which is why the linked_matrix approach mentioned above was proposed. (discussion here: #4516 (comment)) For example: for
Thank you for the clarification. I think I figured out the issue. Let me consider how we can fix it.
@tjruwase the PR sorts the tensor partitions numerically before merging them into a single optimizer state. However, we also need to ensure that the merged optimizer state (fp32_tensors) aligns consistently with the model parameters; currently it does not in my case. Unfortunately, it's challenging to create demo code for this issue. Could you share your thoughts on this? Here are the relevant files:
@tohtana I've seen some inconsistencies with my previous understanding; would you mind taking a look?
Maybe it is not easy to get the path from the DP index. Shall we use a method similar to the one here? https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py#L77C1-L84C22
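The referenced approach in zero_to_fp32.py derives ordering from the file names themselves. A sketch of that idea, assuming the standard `*_zero_pp_rank_<dp>_mp_rank_<mp>_*` naming (the `rank_key` helper and the tie-break choice are assumptions for illustration, not the library's actual implementation):

```python
# Sketch: build a numeric sort key from both ranks embedded in a ZeRO
# checkpoint file name, so files group by MP rank and then order by DP rank.
import re

PATTERN = re.compile(r"zero_pp_rank_(\d+)_mp_rank_(\d+)")

def rank_key(path):
    match = PATTERN.search(path)
    if match is None:
        raise ValueError(f"unrecognized checkpoint file name: {path}")
    dp_rank, mp_rank = (int(g) for g in match.groups())
    # Group by model-parallel rank first, then order by data-parallel rank.
    return (mp_rank, dp_rank)

files = [
    "bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt",
    "bf16_zero_pp_rank_0_mp_rank_01_optim_states.pt",
    "bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt",
]
ordered = sorted(files, key=rank_key)
```

Whether MP or DP should be the primary key depends on how the merge consumes the files; the point is only that both indices are compared as integers.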
I apologize for the confusion. The main issue stemmed from non-corresponding partitions, hence the proposal for the linked matrix method. If the changes in DeepSpeed #5342 to sort the files in get_zero_files() are implemented, I believe that will resolve the issue. Once the new PR DeepSpeed #5395 is merged, I think we can close the ticket.
…5395) This PR resolves the issue reported in #5283. To resolve the issue, we sort files of sharded optimizer states based on DP indices. --------- Co-authored-by: Olatunji Ruwase <[email protected]>
Describe the bug
Background:
Automatic adjustment of ZeRO's optimizer state partitioning with a new world size is not currently supported.
Issues:
I used the code from PR #4516 to convert a DeepSpeed checkpoint without TP and PP (128 ranks) to another DeepSpeed checkpoint (32 ranks). In my testing (ZeRO2), I found that the merge order is not correct.
In this example, I found that it was neither alphabetical order nor numerical order.
Expected behavior
We need to get the linked matrix for the partitioning, so that we can determine the correct order. Maybe something like this: