This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
Updates to sync_float_amax_history (#211)
Summary: Update docs, make sure this is friendly to dynamo.

### Perf

| PyTorch Version | float8 Version | Eager Iterations per Second | Compile |
| -- | -- | -- | -- |
| Nightly | Main | 1.15 it/s | 2.10 it/s |
| Nightly | This PR | 1.16 it/s | 2.27 it/s |

| Trace | Compile URL | Eager |
| -- | -- | -- |
| This PR | https://fburl.com/753ztao4 | https://fburl.com/34yftzao |
| Main | https://fburl.com/a0gh9iof | https://fburl.com/u9c4ilmp |

### Things I have done/changed

#### Commit 1
- [x] We previously had an `fp8_classes` argument that would be passed in; this was to enable working with the separate TP/SP classes. Since we plan to have DTensor be the solution, I am removing it for now.
- [x] I put the `child.amax_and_scale_synced` module mutation under the `enable_amax_init` flag; the module mutation seemed to be causing graph breaks.

#### Commit 2
- [x] We previously had all the history buffers be scalar tensors. This meant that to construct the combined tensor we needed to call `torch.Tensor`, which was causing an HtoD sync under `torch.compile`. I needed to add a single dimension of size 1 and pipe that through all the places.
- [x] Note that this meant we needed to update `to_hp` to send back the original precision, because the scale upcasts the `_data` tensor ([line](f3630d0#diff-94b99416a4df6d75c548de330c1f71505e830b3afff114213d131cf2620597efR57-R59)); see the illustration below.

#### Commit 3
- [x] Rewrote the sync function to do the `torch.roll()` on all the histories at once (sketched below). Side note: not sure if this is more expensive than two clones, since we really don't care about the wrapping behavior.
- [x] Same for generating the new scales from the grouped histories.

##### Things to do
- There are still two loops, and those are for mutating the actual module values; not sure if there is another way around this.
- Going to try the functional collectives.

Pull Request resolved: #211

Reviewed By: awgu

Differential Revision: D53779974

Pulled By: drisspg

fbshipit-source-id: 0a07f247d41d58f1934a69d194f81c5dea230eb1
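The `to_hp` note in Commit 2 comes down to dtype promotion: the `_data` payload is fp8 while the scale is fp32, so applying the scale promotes the result to fp32 and the conversion has to cast back to the tensor's original precision. The helper below is a hedged, standalone illustration of that idea; the name `to_hp`, the argument layout, and the e4m3 dtype choice are assumptions rather than the library's actual method signature.

```python
import torch

def to_hp(data_fp8: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype) -> torch.Tensor:
    """Illustrative only: undo the fp8 scaling and return the original precision.

    Dividing by the fp32 `scale` promotes the result to fp32, which is why the
    final cast back to `orig_dtype` (e.g. bfloat16) is needed.
    """
    return (data_fp8.to(torch.float32) / scale).to(orig_dtype)


# Example round trip (assumes a recent PyTorch build with fp8 dtypes):
x = torch.randn(4, dtype=torch.bfloat16)
scale = torch.tensor(8.0)                      # fp32 scale
x_fp8 = (x.to(torch.float32) * scale).to(torch.float8_e4m3fn)
x_back = to_hp(x_fp8, scale, torch.bfloat16)
assert x_back.dtype == torch.bfloat16          # dtype matches the input again
```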
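The Commit 3 rewrite can be pictured with the standalone sketch below. It is not the actual `sync_float8_amax_and_scale_history` implementation; the function name, the `(1,)`-shaped amax buffers, and the e4m3 scale formula are assumptions. What it shows is the point made above: with size-1 buffers the combined tensors can be built with `torch.cat`/`torch.stack` (no `torch.Tensor(...)` HtoD sync under `torch.compile`), and a single `torch.roll` updates every module's history at once instead of a per-module Python loop.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # assumes a recent PyTorch with fp8 dtypes

def batched_history_sync(amaxes: list[torch.Tensor], histories: list[torch.Tensor]):
    """Illustrative batched sync: roll all histories once and derive new scales.

    amaxes:    N buffers of shape (1,), the latest amax per module.
    histories: N buffers of shape (history_len,), the rolling amax windows.
    """
    combined_amax = torch.cat(amaxes)        # (N,) -- built on device, no HtoD sync
    combined_hist = torch.stack(histories)   # (N, history_len)

    # One roll over every history at once; slot 0 then receives the newest amax.
    rolled = torch.roll(combined_hist, shifts=1, dims=1)
    rolled[:, 0] = combined_amax

    # New scale per module from the grouped histories (max over the window here).
    window_max = rolled.amax(dim=1).clamp(min=1e-12)
    new_scales = E4M3_MAX / window_max
    return rolled, new_scales


# Toy usage with two modules and a history length of 4:
amaxes = [torch.tensor([0.5]), torch.tensor([2.0])]
histories = [torch.zeros(4), torch.zeros(4)]
rolled, scales = batched_history_sync(amaxes, histories)
```

Writing the rolled rows and the new scales back into each module's buffers is still per-module, which matches the remaining two loops mentioned under "Things to do".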
1 parent 0af8433 · commit 956195b · 4 changed files with 162 additions and 97 deletions.
I think adding this `if` broke `enable_amax_init = False`, because we still check and raise:

float8_experimental/float8_experimental/float8_linear.py, lines 271 to 273 in 956195b
If I understand correctly, checking whether amax and scale are synced is orthogonal to whether we enable amax init in general. Only on the 1st iteration, if we disable amax init, can we assume that the amax and scale are already synced?
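For readers without the file open, the interaction the thread is describing can be reproduced with a toy control-flow sketch. The flag names mirror the ones used in the discussion, but the code below is an illustration of the concern, not a quote of lines 271 to 273 of float8_linear.py.

```python
# Illustration of the concern: the synced flag is only ever set under the
# enable_amax_init branch, while the check-and-raise runs unconditionally.
enable_amax_init = False
amax_and_scale_synced = False

if enable_amax_init:
    # The module mutation that the commit moved under the flag.
    amax_and_scale_synced = True

# With enable_amax_init=False the guard still fires, which is the reported break.
if not amax_and_scale_synced:
    raise AssertionError("amax and scale were not synced")
```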