Add APIs to offload states of model, optimizer, and engine #6011
Conversation
Hi @tohtana, thank you for your work. I've been trying the new APIs to test model offloading in a multi-model deployment (e.g., deepspeed-chat) as part of #5620. Although the API initially works in offloading a model and reducing GPU memory, after bringing the model back and completing the first training iteration (i.e., once the optimiser states have been updated), I get an error.
Thank you for reporting, @kfertakis! I have an example script showing the usage of the APIs. Can you try this?
So I tested the issue again with various models, and the problem seems to be model-size related: it does not occur for smaller models (i.e., <= 1B params, e.g., gpt2, gpt2-medium) but it does for bigger ones (i.e., OPT-1.3B, Mistral-7B). Is there anything I could do to investigate it further and debug it? By the way, I should mention that I'm testing this in a single-node, single-GPU configuration (i.e., a single worker), so ZeRO3 should not have to partition data across other workers. I will also test the benchmark you referenced with an artificially larger model size setting. Thanks again.
Hi @kfertakis, I tried this example with a 4B model and it worked. Can you try this in your environment?
@tohtana, I wonder if it would be useful to expose helper functions that report which device each offloadable state currently resides on, so users can verify offload/reload behavior. @kfertakis, would love to get your thoughts as well on whether any of the above would be useful? Thanks!
Hey, thanks for the comments. @tohtana, I've tried the example you provided and it does seem to work, so I'm sharing a fork of the DeepSpeed-Examples repo to showcase the problem. I've modified the DeepSpeed-Chat code to use the new offloading APIs; running it should lead to the error. @tjruwase thanks for the reference. Current problem aside, I can see how the helper functions can be useful in the future for ensuring consistency. Thanks.
Hi @kfertakis, thank you for sharing the repro. It seems that the actual issue is related to ZeRO3's prefetching. I opened #6557 as a workaround to address this issue. Can you try the branch from that PR?
Hi @tohtana, thank you for your work. I tried your branch and the issue seems to be fixed. I will continue testing and will raise any new issues, but for now the fix works.
I also wanted to ask if the offloading functionality could be extended to support |
@tjruwase Let me address this in another PR after this one is merged.
Thank you @kfertakis for validating the fix.
Let me consider how to do this. Please feel free to open a new issue to track it as I am going to merge this PR first. |
Parameters prefetched by ZeRO3 are sometimes not used. This occurs when the actual sub-module execution order differs from the previously traced order. As a result, the state of the allgather handle for such a parameter remains `INFLIGHT`, causing functions like `empty_partition_cache` to detect it and throw an error. This PR resolves the issue by ensuring that the communication finishes and the parameters are freed. As this issue was mentioned in #6011, this PR includes the changes from that branch; we need to merge #6011 first.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
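To make the failure mode concrete, here is a small self-contained sketch of the pattern described above (not DeepSpeed's actual patch; `DummyHandle` and `Param` are illustrative stand-ins, while the status names mirror DeepSpeed's `ZeroParamStatus`):

```python
from enum import Enum

class ZeroParamStatus(Enum):
    NOT_AVAILABLE = 1  # partitioned; only the local shard is held
    INFLIGHT = 2       # allgather launched (e.g., by the prefetcher) but not awaited
    AVAILABLE = 3      # fully gathered on this rank

class DummyHandle:
    def wait(self):
        pass  # stands in for waiting on the async allgather

class Param:
    def __init__(self, name, status, handle=None):
        self.name, self.status, self.handle = name, status, handle

def empty_partition_cache(params):
    for p in params:
        if p.status is ZeroParamStatus.INFLIGHT:
            # Before the fix: a prefetched-but-unused parameter still has a
            # pending allgather, so this case raised an error. After the fix:
            # finish the communication so the buffer can be freed safely.
            p.handle.wait()
            p.status = ZeroParamStatus.AVAILABLE
        if p.status is ZeroParamStatus.AVAILABLE:
            p.status = ZeroParamStatus.NOT_AVAILABLE  # release the gathered buffer

empty_partition_cache([Param("w1", ZeroParamStatus.AVAILABLE),
                       Param("w2", ZeroParamStatus.INFLIGHT, DummyHandle())])
```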
This PR adds an API `deepspeed.runtime.zero.offload_states.get_state_devices`, which gets the devices of offload states, as suggested in this [comment](#6011 (comment)). We could lift this up to `deepspeed.utils`, but we would need to resolve a circular import: user code -> `deepspeed.utils` -> `deepspeed.utils.offload_states` -> `deepspeed.runtime.zero` -> `deepspeed.runtime.zero.partition_parameters` -> `deepspeed.utils`. This will require a significant refactoring as long as we have `OffloadStateTypeEnum` in `deepspeed.runtime.zero`.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
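A usage sketch, assuming `get_state_devices(model, state)` takes the engine and a state type and returns the set of `torch.device`s currently holding that state (the import paths follow the module names given above; verify against your DeepSpeed version):

```python
import torch
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum
from deepspeed.runtime.zero.offload_states import get_state_devices

# After engine.offload_states(), the optimizer states should live on the CPU only.
devices = get_state_devices(engine, OffloadStateTypeEnum.optim_states)
assert devices == {torch.device("cpu")}, f"unexpected devices: {devices}"
```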
This PR adds the following APIs to offload model, optimizer, and engine states.
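The API list itself is truncated in this capture; the engine methods look approximately like this (a sketch of the signatures, not the full implementation):

```python
from typing import Container, Optional

from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

class DeepSpeedEngine:  # excerpt for illustration
    def offload_states(self,
                       include: Optional[Container[OffloadStateTypeEnum]] = None,
                       device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
                       pin_memory: bool = True,
                       non_blocking: bool = False) -> None:
        """Offload the engine's states (model, optimizer, gradients) to `device`.
        `include` selects a subset of states; None offloads all supported states."""

    def reload_states(self, non_blocking: bool = False) -> None:
        """Reload all offloaded states back to their original device."""
```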
Here is the typical usage.
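A minimal sketch of the flow (assuming a ZeRO stage-3 engine created with `deepspeed.initialize`; `model` and `ds_config` are placeholders):

```python
import deepspeed

engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)

# ... train for a while ...

# Move model, optimizer, and engine buffers to CPU to free GPU memory,
# e.g., while another model runs on the same device.
engine.offload_states()

# ... use the freed GPU memory for something else ...

# Bring everything back before resuming training.
engine.reload_states()
```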
You can selectively offload states to balance the offloading overhead and memory saving.
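For example (enum member names follow `OffloadStateTypeEnum`; treat any that differ in your DeepSpeed version as illustrative):

```python
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

# Offload only the optimizer states and low-precision (lp) parameters;
# keep gradients and high-precision params on the GPU.
engine.offload_states(include=[OffloadStateTypeEnum.optim_states,
                               OffloadStateTypeEnum.lp_params],
                      device=OffloadDeviceEnum.cpu,
                      pin_memory=True,     # pinned host buffers make reload faster
                      non_blocking=True)   # overlap copies with host-side work

# ...

engine.reload_states(non_blocking=True)
```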
Performance (4.3B parameters / 4x A100)
`python output_table.py`
TODO: