Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #17440: Inject desired pattern for handling Transpose for fp8 gemm…
… rewrite Imported from GitHub PR #17440 Related to #17276 and #16975. This PR updates the GemmRewriter to handle the transpose of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern post-rewrite. The desired transformation is now injected into GemmRewriter, ensuring the problematic layout is handled internally. This PR transforms the following error-prone pattern, where the transpose of a non-descending layout is the issue: ``` a = f8e4m3fn[x,y]{0,1} xxx transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions=(1,0) custom-call(a,...) ``` to ``` a = f8e4m3fn[x,y]{0,1} xxx bt = f8e4m3fn[y,x]{1,0} bitcast(a) transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions=(1,0) bt.1= f8e4m3fn[y,x]{0,1} bitcast(transpose.1) custom-call(bt.1,...) ``` Copybara import of the project: -- 237c032 by shuw <[email protected]>: Improve TransposeMatrix -- 508cd69 by Shu Wang <[email protected]>: Fix bug of permutation. -- c55e8a9 by shuw <[email protected]>: clang format -- ad0a4ba by Shu Wang <[email protected]>: Add unittest. -- 1d45b4d by Shu Wang <[email protected]>: Remove uncessary space. -- 7837845 by Shu Wang <[email protected]>: Update unittest. -- b479c21 by shuw <[email protected]>: Improve TransposeMatrix Merging this change closes #17440 FUTURE_COPYBARA_INTEGRATE_REVIEW=#17440 from wenscarl:fp8_regulate_transpose b479c21 PiperOrigin-RevId: 680886834
- Loading branch information