PR #16975: Add a few related optimization passes for fp8 gemm custom-calls.

Imported from GitHub PR #16975

This fixes a convergence issue in fp8 training, tested on GPT3 models. Before the change, the training loss stays flat near 11; with the new passes it decreases as expected.

Before:
```
NETWORK BACKEND MATH SDPA XLA_EXTRAS GPUs STEPS/SEC LOSS WALLSECS
GPT5B XLA fp8 FA 8 1.064 11.019 1571

[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 11.015041
[PAX STATUS] step_i: 200, training loss: 11.016165
[PAX STATUS] step_i: 300, training loss: 11.016386
[PAX STATUS] step_i: 400, training loss: 11.014653
[PAX STATUS] step_i: 500, training loss: 11.014734
[PAX STATUS] step_i: 600, training loss: 11.01613
[PAX STATUS] step_i: 700, training loss: 11.009399
[PAX STATUS] step_i: 800, training loss: 11.017071
[PAX STATUS] step_i: 900, training loss: 11.014582
[PAX STATUS] step_i: 1000, training loss: 11.013434
[PAX STATUS] step_i: 1100, training loss: 11.021271
[PAX STATUS] step_i: 1200, training loss: 11.008364
[PAX STATUS] step_i: 1300, training loss: 11.0198145
[PAX STATUS] step_i: 1400, training loss: 11.01253
[PAX STATUS] step_i: 1500, training loss: 11.019016
```

After:
```
NETWORK BACKEND MATH SDPA GPUs STEPS/SEC LOSS WALLSECS
GPT5B XLA fp8 FA 8 1.020 3.797 1647

[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 6.150083
[PAX STATUS] step_i: 200, training loss: 5.8871064
[PAX STATUS] step_i: 300, training loss: 5.4491887
[PAX STATUS] step_i: 400, training loss: 5.6384015
[PAX STATUS] step_i: 500, training loss: 5.273538
[PAX STATUS] step_i: 600, training loss: 5.2011905
[PAX STATUS] step_i: 700, training loss: 4.903013
[PAX STATUS] step_i: 800, training loss: 4.62972
[PAX STATUS] step_i: 900, training loss: 4.507727
[PAX STATUS] step_i: 1000, training loss: 4.625259
[PAX STATUS] step_i: 1100, training loss: 4.428066
[PAX STATUS] step_i: 1200, training loss: 4.252451
[PAX STATUS] step_i: 1300, training loss: 3.8448389
[PAX STATUS] step_i: 1400, training loss: 3.8578327
[PAX STATUS] step_i: 1500, training loss: 3.796958
```

Copybara import of the project:

--
81af29c by Elfie Guo <[email protected]>:

Add a few related optimization passes for the fp8 gemm rewriter.

Merging this change closes #16975

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16975 from elfiegg:pass 81af29c
PiperOrigin-RevId: 684532401
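For context, XLA's fp8 gemm rewriter pattern-matches a scaled quantize/dequantize-then-dot HLO graph and fuses it into a single fp8 gemm custom-call. The following is a minimal NumPy sketch of that numerical pattern only, not XLA code: `E4M3_MAX`, the helper names, and the unit scales are illustrative assumptions (and the sketch clamps to the e4m3 range without actually rounding to fp8 precision).

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn


def quantize_fp8(x, scale):
    # Divide by the scale and clamp into the fp8 range. This mirrors the
    # convert/clamp subgraph the rewriter looks for; real fp8 would also
    # round the mantissa, which this sketch deliberately omits.
    return np.clip(x / scale, -E4M3_MAX, E4M3_MAX)


def scaled_matmul(a, b, a_scale, b_scale):
    # dequantize(quantize(a)) @ dequantize(quantize(b)) -- the shape of
    # computation an fp8 gemm custom-call consumes directly, taking the
    # scales as side inputs instead of materializing the dequantized
    # operands.
    a_q = quantize_fp8(a, a_scale)
    b_q = quantize_fp8(b, b_scale)
    return (a_q * a_scale) @ (b_q * b_scale)


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8)).astype(np.float32)
b = rng.standard_normal((8, 4)).astype(np.float32)
out = scaled_matmul(a, b, np.float32(1.0), np.float32(1.0))
```

With unit scales and in-range inputs the round trip is lossless here, so `out` matches a plain `a @ b`; in real fp8 training the scales track running amax statistics so that operand magnitudes stay inside the e4m3 range.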