
[Backend] Implement scaled_dot(mxfp4, fp8) #4904

Merged: 7 commits merged into main from mxfp_snd on Oct 16, 2024

Conversation

@lezcano (Contributor) commented Oct 14, 2024

This PR includes #4891 and #4895. I will rebase once those have landed.

It includes a number of hacks to work around bugs in DotOperandEncodingAttr. All these are marked as FIXME [Dot LL] to be easy to grep for. @Jokeren is working on a comprehensive revamp of DotOperandEncodingAttr which will get rid of all these. #4895 is the first step in this direction.

@lezcano changed the title from "mxfp snd" to "[Backend] Implement scaled_dot(mxfp4, fp8)" on Oct 14, 2024
@lezcano marked this pull request as draft on October 14, 2024 17:51
@lezcano force-pushed the mxfp_snd branch 4 times, most recently from c15d411 to 104200d, on October 15, 2024 14:23
@lezcano marked this pull request as ready for review on October 15, 2024 14:44
@lezcano force-pushed the mxfp_snd branch 2 times, most recently from 20a64b1 to 33fceb2, on October 15, 2024 16:44
@@ -39,7 +39,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
   // CHECK: offset = 0, size = 4608
   %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
   %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>, #BL>
-  // CHECK-NEXT: offset = 0, size = 4224
+  // CHECK-NEXT: offset = 0, size = 4352
Contributor Author:

nb. These changes are coming from the change in lib/Analysis/Allocation.cpp

Contributor:

It's OK, this path was never tested anyway. It will be tested in my next PR.

// This should be getElemOrder, but we don't have such a method
// TODO Implement getElemOrder and make sure it's consistent with
// getContigPerThread
auto inOrd = gpu::getThreadOrder(srcLayout);
Contributor:

I think we assume getElemOrder == getOrder

Contributor:

getThreadOrder is the same as getOrder except for AMD's AMDMfmaEncodingAttr. I haven't investigated it deeply.
Pinging @zhanglx13 for expertise, maybe.

Contributor Author:

Note that I changed the definition of getThreadOrder in this PR.

Contributor:

To be specific, I was referring to:

SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
  auto order = ::getOrder(*this);
  if (getIsTransposed())
    std::swap(order[0], order[1]);
  return order;
}

I'm not sure whether we should use getOrder or getThreadOrder for this encoding.

lib/Dialect/TritonGPU/IR/Dialect.cpp (resolved)

auto ha = getValuesFromDotOperandLayoutStruct(
    typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);

// FIXME [Dot LL]
// max(repN / 2, 1) is wrong for repN = 1!
Contributor:

Can you elaborate on `// max(repN / 2, 1) is wrong for repN = 1!`?
Why is repN = 1 wrong?

Contributor Author:

We take max(repN / 2, 1) here, and then in the loop inside getValuesFromDotOperandLayoutStruct we pack 4 elements at a time. Instead, the correct implementation packs 2 elements inside the function for opIdx=1 and iterates repN times.
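
For illustration, a minimal standalone sketch of the packing scheme described above (this is not the actual Triton lowering; the helper name, element width, and container types are assumptions): iterate repN times and pack 2 elements per 32-bit slot, so repN = 1 is handled correctly.

#include <cstdint>
#include <vector>

// Hypothetical helper: pack pairs of 16-bit elements of operand B (opIdx=1)
// into 32-bit slots, iterating repN times with 2 elements per iteration,
// rather than packing 4 elements over max(repN / 2, 1) iterations.
std::vector<uint32_t> packOperandB(const std::vector<uint16_t> &elems, int repN) {
  std::vector<uint32_t> packed;
  packed.reserve(repN);
  for (int n = 0; n < repN; ++n) {
    uint32_t lo = elems[2 * n];
    uint32_t hi = elems[2 * n + 1];
    packed.push_back(lo | (hi << 16));
  }
  return packed;
}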

Contributor:

Got it

@ThomasRaoux (Collaborator) left a comment:

Looks good overall, although I didn't look in detail at the LL TODOs.
Just added a few minor comments.

Comment on lines 260 to 261
// FIXME: mma should just return getOrderForDotOperand(0, order.size(),
// kMajor=false)
Collaborator:

I'm also confused by this comment.

@lezcano (Contributor, Author) commented Oct 16, 2024:

Here I just meant that the logic in mma is probably wrong and we just want this function to return what I wrote there. The point is that, in terms of order, the mma layout is the same as DotOperandEncoding(opIdx=0).

Contributor Author:

I had another go at the comment. Third time's the charm.

Comment on lines +271 to +272
order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
                              /*kMajor*/ false);
Collaborator:

why is kMajor always false here?

Contributor:

This is getting the warp order, not the element order, so m is the fastest-changing dimension for opIdx=0. I think the confusion may arise from the variable name kMajor.

Contributor:

I don't have a suggestion for improvement though. Maybe just add some additional comments.

Contributor Author:

Yep, similarly to wgmma, we want the warps to have the exterior dimension (i.e. not K) as their fastest-running dimension.
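
To make this concrete, here is a hedged, self-contained sketch of what getOrderForDotOperand is expected to compute for rank-2 operands (a simplified reconstruction, not necessarily the exact Triton implementation): with kMajor = true the K dimension is the fastest-running one, and with kMajor = false the exterior dimension (M for opIdx = 0, N for opIdx = 1) runs fastest, which is what the warp order above wants.

#include <cassert>
#include <utility>
#include "llvm/ADT/SmallVector.h"

// Simplified sketch for rank == 2 only.
// opIdx = 0 has shape [M, K]; opIdx = 1 has shape [K, N].
// The returned order lists dimension indices from fastest- to slowest-running.
llvm::SmallVector<unsigned> getOrderForDotOperandSketch(unsigned opIdx,
                                                        bool kMajor) {
  assert(opIdx == 0 || opIdx == 1);
  // Row-major default: the last dimension (index 1) runs fastest.
  llvm::SmallVector<unsigned> order = {1, 0};
  // For opIdx = 0 the last dimension is K; for opIdx = 1 it is N.
  bool lastDimIsK = (opIdx == 0);
  // Swap when the default does not match the requested kMajor.
  if (lastDimIsK != kMajor)
    std::swap(order[0], order[1]);
  return order;
}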

vType.getShape(), vType.getElementType(), newVEncoding);
return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
} else {
auto newVEncoding = DotOperandEncodingAttr::get(
Collaborator:

nit: assert that this is an fp8 type?

Contributor Author:

Done, although it's a bit redundant, as we are already asserting this at the beginning of the function and in semantics.py.
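
For reference, a hedged sketch of the kind of check being discussed (the helper name is made up, and it assumes MLIR's Type::isFloat8E4M3FN / Type::isFloat8E5M2 predicates; which fp8 variants scaled_dot actually accepts is determined elsewhere, e.g. in semantics.py).

#include "mlir/IR/Types.h"

// Hypothetical helper: returns true if an MLIR element type is one of the
// fp8 variants, so the caller can assert before building the new
// DotOperandEncodingAttr.
static bool isFp8ElementType(mlir::Type elemTy) {
  return elemTy.isFloat8E4M3FN() || elemTy.isFloat8E5M2();
}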

@ThomasRaoux (Collaborator) left a comment:

LGTM

@lezcano merged commit 9e90089 into triton-lang:main on Oct 16, 2024
7 checks passed
@lezcano deleted the mxfp_snd branch on October 16, 2024 15:21
alexsamardzic pushed a commit to alexsamardzic/triton that referenced this pull request Oct 16, 2024
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024