[Backend] Implement scaled_dot(mxfp4, fp8)
#4904
Conversation
@@ -39,7 +39,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
   // CHECK: offset = 0, size = 4608
   %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
   %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>, #BL>
-  // CHECK-NEXT: offset = 0, size = 4224
+  // CHECK-NEXT: offset = 0, size = 4352
N.B. These changes come from the change in lib/Analysis/Allocation.cpp.
It's OK; this path was never tested anyway. It will be tested in my next PR.
// This should be getElemOrder, but we don't have such a method
// TODO Implement getElemOrder and make sure it's consistent with
// getContigPerThread
auto inOrd = gpu::getThreadOrder(srcLayout);
I think we assume getElemOrder == getOrder
`getThreadOrder` is the same as `getOrder` except for AMD's `AMDMfmaEncodingAttr`. I haven't investigated it deeply; pinging @zhanglx13 for expertise.
Note that I changed the definition of `getThreadOrder` in this PR.
To be specific, I was referring to:
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
  auto order = ::getOrder(*this);
  if (getIsTransposed())
    std::swap(order[0], order[1]);
  return order;
}
I'm not sure if we should use `getOrder` or `getThreadOrder` for this encoding.
auto ha = getValuesFromDotOperandLayoutStruct(
    typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);

// FIXME [Dot LL]
// max(repN / 2, 1) is wrong for repN = 1!
Can you elaborate on `// max(repN / 2, 1) is wrong for repN = 1!`? Why is repN = 1 wrong?
We are taking this `max(repN / 2, 1)` here, and then in the loop inside `getValuesFromDotOperandLayoutStruct` we are packing 4 elements at a time. Rather than that, the correct implementation packs 2 elements inside the function for opIdx=1 and iterates `repN` times.
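To make the counting concrete, here is a minimal standalone sketch of the two iteration schemes; `packElements` is a hypothetical stand-in for the packing that `getValuesFromDotOperandLayoutStruct` actually performs, not the real code.

#include <algorithm>
#include <cstdio>

// Hypothetical stand-in for the packing step; it just reports what would
// be packed at each repetition.
static void packElements(int rep, int elemsPerPack) {
  std::printf("rep %d: packing %d elements\n", rep, elemsPerPack);
}

int main() {
  int repN = 1;
  // Current scheme: max(repN / 2, 1) iterations, 4 elements each.
  // For repN == 1 this still packs 4 elements, which is too many.
  for (int n = 0; n < std::max(repN / 2, 1); ++n)
    packElements(n, 4);
  // Scheme described above: 2 elements per iteration, repN iterations,
  // so repN == 1 packs exactly one pair.
  for (int n = 0; n < repN; ++n)
    packElements(n, 2);
}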
Got it
This is a tentative PR to check how much breaks if we fix this.
Looks good overall, although I didn't look in detail at the LL TODOs. Just added a few minor comments.
lib/Dialect/TritonGPU/IR/Dialect.cpp
// FIXME: mma should just return getOrderForDotOperand(0, order.size(),
// kMajor=false)
I'm also confused by this comment.
Here I just meant that the logic in mma is probably wrong and we just want this function to return what I wrote there. The point is that, in terms of order, the mma layout is the same as `DotOperandEncoding(opIdx=0)`.
I had another go at the comment. Third time's a charm.
order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
                              /*kMajor*/ false);
Why is kMajor always false here?
This is getting the warp order, not the element order. So m is the fastest-changing dimension in opIdx=0. I think the confusion may arise from the variable name `kMajor`.
I don't have a suggestion for improvement though. Maybe just add some additional comments.
Yep, similarly to wgmma, we want the warps to have the exterior dimension (i.e., not K) as their fastest-running dimension.
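For what it's worth, here is a minimal sketch of how I read the kMajor semantics for rank 2; the function name and shape conventions are illustrative, not the actual Triton implementation (which also handles batch dimensions).

#include <array>

// Sketch of the order computation as described in this thread.
std::array<unsigned, 2> orderForDotOperand(unsigned opIdx, bool kMajor) {
  // For opIdx=0 (A: M x K) the K dimension is dim 1;
  // for opIdx=1 (B: K x N) it is dim 0.
  unsigned kDim = (opIdx == 0) ? 1u : 0u;
  unsigned nonK = 1u - kDim;
  // order[0] is the fastest-varying dimension. With kMajor=false the
  // exterior (non-K) dimension runs fastest, which is the warp order we want.
  return kMajor ? std::array<unsigned, 2>{kDim, nonK}
                : std::array<unsigned, 2>{nonK, kDim};
}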
        vType.getShape(), vType.getElementType(), newVEncoding);
    return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
  } else {
    auto newVEncoding = DotOperandEncodingAttr::get(
nit: assert that this is an fp8 type?
Done, although it's a bit redundant, as we are already asserting this at the beginning of the function and in `semantics.py`.
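For reference, a minimal sketch of what such an assertion could look like; the exact fp8 predicates to accept here (isFloat8E4M3FN / isFloat8E5M2) are an assumption on my part, not necessarily what the PR checks.

// Sketch only: verify the element type is one of MLIR's fp8 types.
// Which fp8 variants this path should accept is an assumption.
auto elemTy = vType.getElementType();
assert((elemTy.isFloat8E4M3FN() || elemTy.isFloat8E5M2()) &&
       "expected an fp8 element type");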
LGTM
This PR includes #4891 and #4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` which will get rid of all these. #4895 is the first step in this direction.