Transposing to different layout permutations results in different numerics #17276

Open
elfiegg opened this issue Sep 17, 2024 · 13 comments
elfiegg commented Sep 17, 2024

Hello, we stumbled upon a numerical issue with the modules below while training FP8-quantized models.

ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %b = f8e4m3fn[4096,12288]{1,0} bitcast(%p0)
    %transpose = f8e4m3fn[12288,4096]{1,0} transpose(%b), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[12288,4096]{1,0} %transpose,  f8e4m3fn[4096,16384]{0,1} %p1,f32[] %p2, f32[]%constant_1, f32[]%constant_1, f32[]%constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"6","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  }
 ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %transpose = f8e4m3fn[4096,12288]{0,1} transpose(%p0), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1} %transpose, f8e4m3fn[4096,16384]{0,1} %p1, f32[] %p2, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"7","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  }

These two modules produced different numerics. Upon checking the cuBLAS runtime thunk, it processed the logical layout correctly, and buffer assignment was identical in both cases.

We then wrote a unit test to check the transpose numerics in isolation:

ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %b = f8e4m3fn[4096,12288]{1,0} reshape(%p0)
    %transpose = f8e4m3fn[12288,4096]{1,0} transpose(%b), dimensions={1,0}
    ROOT bitcast = f8e4m3fn[4096,12288]{0,1} reshape(%transpose)
  }
ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    ROOT %transpose = f8e4m3fn[4096,12288]{0,1} transpose(%p0), dimensions={1,0}
  }

The numerical results of the two modules differed in 99% of elements, with relative errors > 1e-2.

Could you please help us understand why transposing to a different layout permutation results in a numerical difference? Is the default vs. non-default layout transpose a known issue, or are we making unintentional assumptions / mistakes?
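For intuition, here is a minimal sketch (plain Python, not XLA code; the helper name `offset` is ours) of how a minor-to-major layout maps a logical index to a physical buffer offset. The point is that the same logical element lives at different buffer offsets under `{1,0}` (the default, row-major layout) and `{0,1}` (column-major), so any consumer that silently assumes the default layout reads the wrong bytes:

```python
def offset(index, shape, minor_to_major):
    """Linear buffer offset of a logical index under an XLA-style layout."""
    off, stride = 0, 1
    for dim in minor_to_major:   # walk dimensions from minor to major
        off += index[dim] * stride
        stride *= shape[dim]
    return off

shape = (2, 3)
# The same logical element (1, 1) sits at different physical offsets:
row_major = offset((1, 1), shape, minor_to_major=(1, 0))  # 1 + 1*3 = 4
col_major = offset((1, 1), shape, minor_to_major=(0, 1))  # 1 + 1*2 = 3
```

If codegen indexes the `{0,1}` buffer with `{1,0}` arithmetic, it reads a different element entirely, which would explain wholesale numerical mismatches rather than small rounding errors.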

elfiegg commented Sep 17, 2024

@kaixih @wenscarl

elfiegg commented Sep 17, 2024

Unit-test reproducer (which we also modified to test the transpose as the root of the modules):

TEST_F(GpuCompilerTest, LayoutNormalizationRequiredForCublasF8) {
  auto cc = backend()
                .default_stream_executor()
                ->GetDeviceDescription()
                .cuda_compute_capability();
  if (!cc.IsAtLeastAmpere()) {
    GTEST_SKIP() << "Autotuning results have only been generated for Ampere "
                 << "and Hopper GPUs";
  }
  const absl::string_view good_hlo_string = R"( 
  HloModule test 

  ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %b = f8e4m3fn[4096,12288]{1,0} bitcast(%p0)
    %transpose = f8e4m3fn[12288,4096]{1,0} transpose(%b), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[12288,4096]{1,0} %transpose,  f8e4m3fn[4096,16384]{0,1} %p1,f32[] %p2, f32[]%constant_1, f32[]%constant_1, f32[]%constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"6","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  })";

  HloModuleConfig config;
  DebugOptions debug_options = GetDebugOptionsForTest();
  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
  debug_options.set_xla_gpu_enable_triton_gemm(false);
  debug_options.set_xla_gpu_cublas_fallback(true);
  config.set_debug_options(debug_options);
  config.set_replica_count(1);
  config.set_num_partitions(1);

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> good_module,
      ParseAndReturnVerifiedModule(good_hlo_string, config));

  const absl::string_view bad_hlo_string = R"(
  HloModule test 

  ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %transpose = f8e4m3fn[4096,12288]{0,1} transpose(%p0), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1} %transpose, f8e4m3fn[4096,16384]{0,1} %p1, f32[] %p2, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"7","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> bad_module,
                          ParseAndReturnVerifiedModule(bad_hlo_string, config));

  EXPECT_TRUE(RunAndCompareTwoModules(good_hlo_string, bad_hlo_string,
                                      ErrorSpec{1e-10, 1e-10}, false));
}

@sergachev

@akuegel is my understanding right that transposes should always use the default layout and that's normally ensured by the layout normalization? If so, should we try to detect the wrong ones at codegen or in the HLO verifier?

@mooskagh

It's indeed expected that cuBLAS GEMMs get their layout normalized, and if you feed the HLO to the optimization passes, it already fails (with a slightly different error):

layout_assignment.cc:321] Check failed: !IsCublasGemm(*instruction) Gemm rewriting should run after layout assignment

We can add a check somewhere that ensures the layout that gets into the custom call is normalized, but that would just enforce an internal invariant; it is (at least in theory) not possible to reach this state from a pre-optimized HLO.

@sergachev

It's not about cuBLAS, it's about transpose alone, see the second reproducer in the first message.

akuegel commented Sep 19, 2024

@sergachev While Layout Normalization will make sure that transposes have the default layout, there could be passes later in the pipeline that create transposes with non-default layout. Note that anything that calls MakeTransposeHlo from hlo_creation_utils will most likely have a non-default layout, as that function infers a layout that will make the transpose a bitcast. This is something we want to avoid, so if you see any pass that runs after LayoutNormalization that calls MakeTransposeHlo, please file a bug or send a PR.
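The "layout that makes the transpose a bitcast" can be sketched in a few lines (a simplified model, not the actual `MakeTransposeHlo` source; `offset` is our helper): if the output layout is the input's minor-to-major permuted along with the transpose, every element keeps its buffer offset, so the op never moves bytes. That is exactly why the inferred layout is non-default whenever the transpose permutation is non-trivial:

```python
def offset(index, shape, minor_to_major):
    """Linear buffer offset of a logical index under an XLA-style layout."""
    off, stride = 0, 1
    for dim in minor_to_major:
        off += index[dim] * stride
        stride *= shape[dim]
    return off

shape, mtm = (2, 3), (1, 0)      # input: f8[2,3]{1,0}, the default layout
t_shape, t_mtm = (3, 2), (0, 1)  # transposed shape with permuted layout
for i in range(2):
    for j in range(3):
        # Element (i, j) and its transposed image (j, i) share one offset:
        # the "transpose" is a pure relabeling of the same buffer.
        assert offset((i, j), shape, mtm) == offset((j, i), t_shape, t_mtm)
```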

elfiegg commented Sep 19, 2024

@akuegel it sounds to me that, generally speaking, we should ensure layout normalization and its associated passes run after all rewriters and op-changing passes, but before codegen, so that they have accounted for all ops, since the layouts normalized by the pass may be a strict requirement.

akuegel commented Sep 20, 2024

We already have HLO passes that rely on transposes having only the default layout. For example, the one I added recently (TransposeDimensionGrouper) only works on transposes with the default layout and returns an error otherwise, so just running layout normalization once more at the end of the pipeline will not fix the issue. The suggestion of @sergachev to make this part of the HloVerifier sounds better to me. It would need to be a verifier option that is off by default but can be turned on in our pipeline after the LayoutNormalization pass.

elfiegg commented Sep 20, 2024

OK, that sounds good! My original comment was more focused on the other instructions involved in layout normalization in a broader sense. Are all the instructions that the layout normalization pass standardizes considered a strict requirement? Or is transpose a special case we stumbled upon that affects correctness?

akuegel commented Sep 23, 2024

Once LayoutNormalization has run, it is quite unlikely that other passes will introduce ops that don't have the default layout. Normally the layout of a new op is derived from the ops surrounding it, so if all those ops have the default layout, the new op will have the default layout as well. Transpose is special because of the MakeTransposeHlo() method, which chooses a non-default layout. I believe it was a mistake to make that function assign a non-default layout, but that would probably be quite hard to change now.
That said, most of the code would still work with any layout, as LayoutNormalization is fairly new and the code was written to support any layout. Only newly added code might be relying on the default layout.

elfiegg commented Sep 23, 2024

The layout of new ops is indeed derived from the surrounding ops, and the "bug" arises because some ops never get a chance to go through the LayoutNormalization pass: Triton first fuses FP8 GEMMs, but during layout normalization the transpose has not yet been inserted by GemmRewriter, and ops within these fusions are not handled either. Then, when the autotuner falls back to cuBLAS and the fused computations are inlined, the cuBLAS GemmRewriter may insert a non-default transpose based on the context. In this situation, would you consider it a bug that layout normalization should also run after inlining the computations, or would it be better to insert a default-layout transpose in the GemmRewriter?

akuegel commented Sep 24, 2024

Ideally we would insert a transpose with default layout in the GemmRewriter. If you have a transpose that preserves the non-default layout of its operand, it can be normalized to have a default layout by adding a bitcast transpose in front and after it. Unfortunately we still don't normalize Dots, which means we often have a bitcast operand of a dot with non-default layout, so if a transpose is inserted between the bitcast and the dot, it would have non-default layout as well.
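The bitcast-sandwich rewrite described here can be demonstrated with a toy model (plain Python under our simplified layout scheme, not GemmRewriter code; `offset` and `logical` are our helpers): a transpose that preserves a non-default layout decomposes into bitcast, default-layout transpose, bitcast, where only the middle step moves bytes:

```python
def offset(index, shape, mtm):
    """Linear buffer offset of a logical index under an XLA-style layout."""
    off, stride = 0, 1
    for d in mtm:
        off += index[d] * stride
        stride *= shape[d]
    return off

def logical(buf, shape, mtm):
    """Read a laid-out buffer back as nested lists in logical order."""
    return [[buf[offset((i, j), shape, mtm)] for j in range(shape[1])]
            for i in range(shape[0])]

a_buf = [10, 11, 12, 13, 14, 15]     # physical buffer of a = [2,3]{0,1}
a = logical(a_buf, (2, 3), (0, 1))   # logical view of a

# Problematic op: transpose(a) whose output keeps non-default layout {0,1}.
want = [[a[i][j] for i in range(2)] for j in range(3)]

# Normalized form:
# 1. bitcast [2,3]{0,1} -> [3,2]{1,0}: same buffer, pure relabeling.
mid = logical(a_buf, (3, 2), (1, 0))
# 2. transpose [3,2]{1,0} -> [2,3]{1,0} with the default layout: the only
#    step that moves bytes, and it is a plain row-major transpose.
t_buf = [mid[j][i] for i in range(2) for j in range(3)]
# 3. bitcast [2,3]{1,0} -> [3,2]{0,1}: same buffer, pure relabeling again.
got = logical(t_buf, (3, 2), (0, 1))

assert got == want  # the sandwich reproduces the original transpose
```

The last point in the comment also follows from this model: if a bitcast feeding a dot has non-default layout, a transpose inserted between them inherits that layout, so the sandwich cannot be applied without first normalizing the dot's operand.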

elfiegg commented Sep 24, 2024

OK, in that case could you please also take a look at #17440 and leave any comments? @wenscarl has a fix for inserting a default-layout transpose in the GemmRewriter.

copybara-service bot pushed a commit that referenced this issue Oct 1, 2024
… rewrite

Imported from GitHub PR #17440

Related to #17276 and #16975.
This PR updates the GemmRewriter to handle the transpose of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern post-rewrite. The desired transformation is now injected into GemmRewriter, ensuring the problematic layout is handled internally. This PR transforms the following error-prone pattern, where the transpose of a non-descending layout is the issue:
```
a = f8e4m3fn[x,y]{0,1} xxx
transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions=(1,0)
custom-call(a,...)
```
to
```
a = f8e4m3fn[x,y]{0,1} xxx
bt = f8e4m3fn[y,x]{1,0} bitcast(a)
transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions=(1,0)
bt.1= f8e4m3fn[y,x]{0,1} bitcast(transpose.1)
custom-call(bt.1,...)
```
Copybara import of the project:

--
237c032 by shuw <[email protected]>:

Improve TransposeMatrix

--
508cd69 by Shu Wang <[email protected]>:

Fix bug of permutation.
--
c55e8a9 by shuw <[email protected]>:

clang format

--
ad0a4ba by Shu Wang <[email protected]>:

Add unittest.
--
1d45b4d by Shu Wang <[email protected]>:

Remove uncessary space.
--
7837845 by Shu Wang <[email protected]>:

Update unittest.

--
b479c21 by shuw <[email protected]>:

Improve TransposeMatrix

Merging this change closes #17440

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17440 from wenscarl:fp8_regulate_transpose b479c21
PiperOrigin-RevId: 680886834