Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inject desired pattern for handling Transpose for fp8 gemm rewrite #17440

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions xla/service/gpu/transforms/gemm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/permutation_util.h"
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
Expand Down Expand Up @@ -362,27 +363,61 @@ std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
// dimension. Keeps the layout the same.
HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
absl::Span<const int64_t> batch_dims) {
auto input_shape = instr->shape();
// Identify the dimensional order which describes a transpose of the
// contracting and non-contracting dimensions of the GEMM.
std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
std::vector<int64_t> permutation(input_shape.dimensions_size(), -1);
// Discard the batch dimensions.
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
for (int64_t batch_dim : batch_dims) {
permutation[batch_dim] = batch_dim;
}
// Identify the non-contracting dimension.
int non_contracting_dim;
for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
for (int i = 0; i < input_shape.dimensions_size(); ++i) {
if (permutation[i] == -1 && contracting_dim != i) {
non_contracting_dim = i;
}
}
permutation[non_contracting_dim] = contracting_dim;
permutation[contracting_dim] = non_contracting_dim;

Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
*new_shape.mutable_layout() = instr->shape().layout();
return instr->AddInstruction(
HloInstruction::CreateTranspose(new_shape, instr, permutation));
if (Layout::Equal()(input_shape.layout(),
LayoutUtil::GetDefaultLayoutForShape(input_shape))) {
permutation[non_contracting_dim] = contracting_dim;
permutation[contracting_dim] = non_contracting_dim;

Shape new_shape = ShapeUtil::PermuteDimensions(permutation, input_shape);
*new_shape.mutable_layout() = input_shape.layout();

return instr->AddInstruction(
HloInstruction::CreateTranspose(new_shape, instr, permutation));
}

Shape normalized_input_shape =
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
input_shape);
auto a0 = MakeBitcastHlo(instr, normalized_input_shape);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be conceptually simpler to insert a copy before the transpose to change the layout of the input? (This assumes that the copy -> transpose sequence is optimized by another pass which I haven't verified.)

Also, can we pick a more descriptive variable name here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pass runs after layout normalization which turns copies into bitcast + transpose, so it should not produce any Copy ops that change the layout (otherwise we would have to run layout normalization again).


std::vector<int64_t> layout_permuation(
input_shape.layout().minor_to_major().begin(),
input_shape.layout().minor_to_major().end());
absl::c_reverse(layout_permuation);
auto inv_perm = InversePermutation(layout_permuation);

int new_contracting_dim = inv_perm[contracting_dim];
int new_non_contracting_dim = inv_perm[non_contracting_dim];
absl::c_iota(permutation, 0);
std::swap(permutation[new_contracting_dim],
permutation[new_non_contracting_dim]);

Shape transpose_shape =
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
ShapeUtil::PermuteDimensions(permutation, a0->shape());
*transpose_shape.mutable_layout() = a0->shape().layout();

HloInstruction *normalized_transpose = instr->AddInstruction(
HloInstruction::CreateTranspose(transpose_shape, a0, permutation));

Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape);
*final_shape.mutable_layout() = input_shape.layout();
return MakeBitcastHlo(normalized_transpose, final_shape);
}

// If the bias is a sequence of ops that depend only on broadcasts of
Expand Down
52 changes: 52 additions & 0 deletions xla/service/gpu/transforms/gemm_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5032,6 +5032,58 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) {
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
const char* hlo_text = R"(
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
HloModule test
ENTRY test {
x = <<F8E4M3>>[2,16,32]{1,0,2} parameter(0)
y = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
x_scale = f32[] parameter(2)
y_scale = f32[] parameter(3)
dq_scale = f32[] multiply(x_scale, y_scale)
dq_scale_bcast = f32[2,16,16] broadcast(dq_scale), dimensions={}
out.0 = f32[2,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
ROOT out = f32[2,16,16] multiply(out.0, dq_scale_bcast)
}
)";

CheckFp8IfSupported(hlo_text);
RunAndFilecheckHloRewrite(
hlo_text,
GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
R"(
; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[2,16,32], {{.*}}: <<F8E4M3>>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[2,16,32]{1,0,2} parameter(0)
; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <<F8E4M3>>[32,2,16]{2,1,0} bitcast([[P0]])
; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <<F8E4M3>>[16,2,32]{2,1,0} transpose([[P0_BT]]), dimensions={2,1,0}
; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <<F8E4M3>>[2,32,16]{1,0,2} bitcast([[P0_TR]])
; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]])
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config={
; CHECK-DAG: "alpha_real":1
; CHECK-DAG: "alpha_imag":0
; CHECK-DAG: "beta":0
; CHECK-DAG: "dot_dimension_numbers":{
; CHECK-DAG: "lhs_contracting_dimensions":["1"]
; CHECK-DAG: "rhs_contracting_dimensions":["2"]
; CHECK-DAG: "lhs_batch_dimensions":["0"]
; CHECK-DAG: "rhs_batch_dimensions":["0"]
; CHECK-DAG: }
; CHECK-DAG: "precision_config":{
; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
; CHECK-DAG: }
; CHECK-DAG: "epilogue":"DEFAULT"
; CHECK: }
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
const char* hlo_text = R"(
HloModule test
Expand Down
Loading