PR #17440: Inject desired pattern for handling Transpose for fp8 gemm rewrite

Imported from GitHub PR #17440

Related to #17276 and #16975.
This PR updates the GemmRewriter to handle transposes of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern after the rewrite. The desired transformation is now injected into the GemmRewriter itself, so the problematic layout is handled internally. Concretely, the rewriter transforms the following pattern, in which the transpose of a non-descending layout is the problem:
```
a = f8e4m3fn[x,y]{0,1} xxx
transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions={1,0}
custom-call(transpose.0, ...)
```
to
```
a = f8e4m3fn[x,y]{0,1} xxx
bt = f8e4m3fn[y,x]{1,0} bitcast(a)
transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions={1,0}
bt.1 = f8e4m3fn[y,x]{0,1} bitcast(transpose.1)
custom-call(bt.1, ...)
```
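The bitcasts here are metadata-only: a [x,y] buffer with layout {0,1} (dimension 0 minor-most) and a [y,x] buffer with layout {1,0} (dimension 1 minor-most) describe the same physical memory. As a quick illustration, the following minimal standalone sketch (hypothetical C++, not part of this PR) checks that the linearized offsets of the two views coincide element for element:

```
#include <cstdint>
#include <iostream>

// Hypothetical sketch, not XLA code: element (i, j) of an [x,y] array with
// layout {0,1} (dimension 0 minor-most) lives at offset i + j * x, which is
// exactly where element (j, i) of a [y,x] array with layout {1,0}
// (dimension 1 minor-most) lives. The bitcast therefore moves no data.
int64_t OffsetXYColMajor(int64_t i, int64_t j, int64_t x) { return i + j * x; }
int64_t OffsetYXRowMajor(int64_t j, int64_t i, int64_t x) { return j * x + i; }

int main() {
  const int64_t x = 3, y = 4;
  for (int64_t i = 0; i < x; ++i) {
    for (int64_t j = 0; j < y; ++j) {
      // Both views address the same element of the underlying buffer.
      if (OffsetXYColMajor(i, j, x) != OffsetYXRowMajor(j, i, x)) {
        std::cout << "offsets diverge\n";
        return 1;
      }
    }
  }
  std::cout << "[x,y]{0,1} and [y,x]{1,0} share one physical layout\n";
  return 0;
}
```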
Copybara import of the project:

--
237c032 by shuw <[email protected]>:

Improve TransposeMatrix

--
508cd69 by Shu Wang <[email protected]>:

Fix permutation bug.

--
c55e8a9 by shuw <[email protected]>:

clang format

--
ad0a4ba by Shu Wang <[email protected]>:

Add unittest.

--
1d45b4d by Shu Wang <[email protected]>:

Remove unnecessary space.

--
7837845 by Shu Wang <[email protected]>:

Update unittest.

--
b479c21 by shuw <[email protected]>:

Improve TransposeMatrix

--
b633184 by Shu Wang <[email protected]>:

Update unittest shape and BUILD file.

Merging this change closes #17440

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17440 from wenscarl:fp8_regulate_transpose b633184
PiperOrigin-RevId: 680886834
wenscarl authored and Google-ML-Automation committed Oct 2, 2024
1 parent 93eb146 commit 8d04204
Showing 3 changed files with 97 additions and 8 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/BUILD
@@ -1671,6 +1671,7 @@ cc_library(
 deps = [
     "//xla:literal",
     "//xla:literal_util",
+    "//xla:permutation_util",
     "//xla:shape_util",
     "//xla:status_macros",
     "//xla:types",
52 changes: 44 additions & 8 deletions xla/service/gpu/transforms/gemm_rewriter.cc
@@ -46,8 +46,10 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/permutation_util.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
@@ -362,27 +364,61 @@ std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
 // dimension. Keeps the layout the same.
 HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
                                 absl::Span<const int64_t> batch_dims) {
+  auto input_shape = instr->shape();
   // Identify the dimensional order which describes a transpose of the
   // contracting and non-contracting dimensions of the GEMM.
-  std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
+  std::vector<int64_t> permutation(input_shape.dimensions_size(), -1);
   // Discard the batch dimensions.
   for (int64_t batch_dim : batch_dims) {
     permutation[batch_dim] = batch_dim;
   }
   // Identify the non-contracting dimension.
   int non_contracting_dim;
-  for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
+  for (int i = 0; i < input_shape.dimensions_size(); ++i) {
     if (permutation[i] == -1 && contracting_dim != i) {
       non_contracting_dim = i;
     }
   }
-  permutation[non_contracting_dim] = contracting_dim;
-  permutation[contracting_dim] = non_contracting_dim;

-  Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
-  *new_shape.mutable_layout() = instr->shape().layout();
-  return instr->AddInstruction(
-      HloInstruction::CreateTranspose(new_shape, instr, permutation));
+  if (Layout::Equal()(input_shape.layout(),
+                      LayoutUtil::GetDefaultLayoutForShape(input_shape))) {
+    permutation[non_contracting_dim] = contracting_dim;
+    permutation[contracting_dim] = non_contracting_dim;
+
+    Shape new_shape = ShapeUtil::PermuteDimensions(permutation, input_shape);
+    *new_shape.mutable_layout() = input_shape.layout();
+
+    return instr->AddInstruction(
+        HloInstruction::CreateTranspose(new_shape, instr, permutation));
+  }
+
+  Shape normalized_input_shape =
+      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+          input_shape);
+  auto a0 = MakeBitcastHlo(instr, normalized_input_shape);
+
+  std::vector<int64_t> layout_permutation(
+      input_shape.layout().minor_to_major().begin(),
+      input_shape.layout().minor_to_major().end());
+  absl::c_reverse(layout_permutation);
+  auto inv_perm = InversePermutation(layout_permutation);
+
+  int new_contracting_dim = inv_perm[contracting_dim];
+  int new_non_contracting_dim = inv_perm[non_contracting_dim];
+  absl::c_iota(permutation, 0);
+  std::swap(permutation[new_contracting_dim],
+            permutation[new_non_contracting_dim]);
+
+  Shape transpose_shape =
+      ShapeUtil::PermuteDimensions(permutation, a0->shape());
+  *transpose_shape.mutable_layout() = a0->shape().layout();
+
+  HloInstruction *normalized_transpose = instr->AddInstruction(
+      HloInstruction::CreateTranspose(transpose_shape, a0, permutation));
+
+  Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape);
+  *final_shape.mutable_layout() = input_shape.layout();
+  return MakeBitcastHlo(normalized_transpose, final_shape);
 }

// If the bias is a sequence of ops that depend only on broadcasts of
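To make the non-descending branch concrete, here is a standalone trace (a plain C++ sketch, not XLA code; InversePermutation is re-implemented locally under the assumption that it inverts a permutation the way xla/permutation_util.h does) of the permutation bookkeeping for the LHS of the new unit test below, shape [2,64,32]{1,2,0} with batch dimension 0 and contracting dimension 2:

```
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed behavior of xla::InversePermutation: inv[perm[i]] = i.
std::vector<int64_t> InversePermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (int64_t i = 0; i < static_cast<int64_t>(perm.size()); ++i) {
    inv[perm[i]] = i;
  }
  return inv;
}

int main() {
  const int64_t contracting_dim = 2, non_contracting_dim = 1;
  // minor_to_major of the input layout {1,2,0}, reversed to major-to-minor.
  std::vector<int64_t> layout_permutation = {1, 2, 0};
  std::reverse(layout_permutation.begin(), layout_permutation.end());
  // layout_permutation is now {0, 2, 1}; its inverse is also {0, 2, 1}.
  std::vector<int64_t> inv_perm = InversePermutation(layout_permutation);

  // Logical dims 2 and 1 of [2,64,32]{1,2,0} map to dims 1 and 2 of the
  // bitcast-normalized shape [2,32,64]{2,1,0}.
  int64_t new_contracting = inv_perm[contracting_dim];          // 1
  int64_t new_non_contracting = inv_perm[non_contracting_dim];  // 2

  std::vector<int64_t> permutation = {0, 1, 2};
  std::swap(permutation[new_contracting], permutation[new_non_contracting]);
  // Prints 0 2 1: the dimensions={0,2,1} transpose that the FileCheck
  // expectations in the new unit test look for.
  for (int64_t p : permutation) std::cout << p << ' ';
  std::cout << '\n';
  return 0;
}
```

Applying inv_perm once more and restoring the original {1,2,0} layout then yields the final [2,32,64]{1,2,0} bitcast that the custom call consumes, matching the [[P0_BT1]] line in the FileCheck expectations.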
52 changes: 52 additions & 0 deletions xla/service/gpu/transforms/gemm_rewriter_test.cc
@@ -5032,6 +5032,58 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) {
  const char* hlo_text = R"(
HloModule test
ENTRY test {
  x = <<F8E4M3>>[2,64,32]{1,2,0} parameter(0)
  y = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
  x_scale = f32[] parameter(2)
  y_scale = f32[] parameter(3)
  dq_scale = f32[] multiply(x_scale, y_scale)
  dq_scale_bcast = f32[2,64,16] broadcast(dq_scale), dimensions={}
  out.0 = f32[2,64,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
  ROOT out = f32[2,64,16] multiply(out.0, dq_scale_bcast)
}
)";

  CheckFp8IfSupported(hlo_text);
  RunAndFilecheckHloRewrite(
      hlo_text,
      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
      R"(
; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[2,64,32], {{.*}}: <<F8E4M3>>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,64,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[2,64,32]{1,2,0} parameter(0)
; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <<F8E4M3>>[2,32,64]{2,1,0} bitcast([[P0]])
; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <<F8E4M3>>[2,64,32]{2,1,0} transpose([[P0_BT]]), dimensions={0,2,1}
; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <<F8E4M3>>[2,32,64]{1,2,0} bitcast([[P0_TR]])
; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]])
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,64,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config={
; CHECK-DAG: "alpha_real":1
; CHECK-DAG: "alpha_imag":0
; CHECK-DAG: "beta":0
; CHECK-DAG: "dot_dimension_numbers":{
; CHECK-DAG: "lhs_contracting_dimensions":["1"]
; CHECK-DAG: "rhs_contracting_dimensions":["2"]
; CHECK-DAG: "lhs_batch_dimensions":["0"]
; CHECK-DAG: "rhs_batch_dimensions":["0"]
; CHECK-DAG: }
; CHECK-DAG: "precision_config":{
; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
; CHECK-DAG: }
; CHECK-DAG: "epilogue":"DEFAULT"
; CHECK: }
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
  const char* hlo_text = R"(
HloModule test
