[AMD] Define an extract slice operation #4804
Conversation
Thanks! I have a bunch of inlined comments. Major issues include refining the semantics of the op and adding more docs/tests for it.
Thanks for meticulously going through the code changes! I have a new commit addressing the comments, and I have also moved the Python test case to the newly proposed location.
Thanks, much better!
Force-pushed from f93367b to 6a23263.
Great! Some final comments inlined. Also, could you git merge origin/main so that I can trigger CI? Right now I cannot, because the branch is not on the latest main.
@hmalgewatta thanks for working on this, and @antiagainst thanks for the detailed review and feedback. Looks great now :)
Hi @antiagainst, in the most recent commit I renamed the test file, added code checking for non-static cases, added more lit tests for failing non-static cases, and made changes to avoid conflicts. I've also synced my fork with the main branch so that you can trigger CI.
Force-pushed from 931cd5c to 5897e00.
Please git merge origin/main again and resolve the conflicts.
Force-pushed from 393a6f0 to 6be6f71.
The design of the op is significantly different from what I would have imagined, so maybe I'm missing some context. Let me know if my comment makes sense; otherwise we may need to discuss this a bit more.
Force-pushed from da4e954 to 6787e58.
Location loc = op->getLoc();
auto srcTy = cast<RankedTensorType>(op.getSource().getType());
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto resultTy = cast<RankedTensorType>(op.getType());
auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy);
auto sizePerThread = triton::gpu::getSizePerThread(srcLayout);
auto totalSizePerThread = sizePerThread[0] * sizePerThread[1];
auto order = triton::gpu::getOrder(srcLayout);

// Calculate valid total number of workers in each dimension
auto shapePerCTA = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
shapePerCTA[0] =
    std::min(static_cast<unsigned>(srcShape[0]), shapePerCTA[0]);
shapePerCTA[1] =
    std::min(static_cast<unsigned>(srcShape[1]), shapePerCTA[1]);

// Rank == 2 checked in the verifier
SmallVector<int64_t, 2> sizes;
for (auto i = 0; i < 2; ++i) {
  sizes.push_back(resultTy.getDimSize(i));
}

auto offsets = op.getStaticOffsets();

// Calculate offsets and sizes in terms of CTA units.
std::vector<int64_t> CTAOffsets{offsets[0] / shapePerCTA[0],
                                offsets[1] / shapePerCTA[1]};
std::vector<int64_t> CTASizes{sizes[0] / shapePerCTA[0],
                              sizes[1] / shapePerCTA[1]};
std::vector<int64_t> CTAPerShape{srcShape[0] / shapePerCTA[0],
                                 srcShape[1] / shapePerCTA[1]};

// The diagram above illustrates the graphical representation of the
// skipElems, tensorStride, and lastIdx variables.
auto skipElems = CTAOffsets[order[1]] *
                     (elemsPerThread[order[0]] * sizePerThread[order[1]]) +
                 CTAOffsets[order[0]] * totalSizePerThread;
auto tensorStride =
    (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
auto lastIdx =
    (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
        elemsPerThread[order[0]] * sizePerThread[order[1]] +
    (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;

assert(lastIdx <= vals.size());

SmallVector<Value> resultVals;
for (int i = skipElems; i < lastIdx; i += tensorStride) {
  for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) {
    assert(i < lastIdx);
    resultVals.push_back(vals[i]);
  }
}
Value ret = packLLElements(loc, this->getTypeConverter(), resultVals,
                           rewriter, resultTy);

rewriter.replaceOp(op, ret);
return success();
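To make the indexing arithmetic concrete, here is a worked illustration with hypothetical layout numbers (not taken from the PR, only to show how skipElems, tensorStride, and lastIdx relate):

// Hypothetical blocked layout: 128x128 source, 32x32 CTA tile, so
// CTAPerShape = {4, 4}; sizePerThread = {4, 4}, totalSizePerThread = 16;
// order = {1, 0}; elemsPerThread = {16, 16} (256 values per thread).
// Slicing with offsets = {0, 64} and sizes = {128, 64} gives
// CTAOffsets = {0, 2} and CTASizes = {4, 2}, hence:
//   skipElems    = 0 * (16 * 4) + 2 * 16               = 32
//   tensorStride = (4 - 2) * 16                        = 32
//   lastIdx      = (0 + 4 - 1) * 16 * 4 + (2 + 2) * 16 = 256
// The copy loop then alternately takes a run of 32 values (two CTA tiles
// along the fastest dimension) and skips 32, keeping half of the 256
// per-thread values, as expected for a half-width slice.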
This lowering makes assumptions on the layout and shape of the operands/destination that are stronger than what is in the verifier, right?
Where do we check that those hold? It is okay to fail lowering if we don't want to support some cases, but we never want to miscompile.
Good catch! +1. We should check that each thread is still holding the same elements in the op verifier.
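For illustration, a minimal sketch of the kind of structural check being discussed, using hypothetical variable names; the PR's actual verifier may differ:

// Inside the op verifier: reject slices that do not fall on CTA-tile
// boundaries, since lowering them would require exchanging values
// between threads and could otherwise silently miscompile.
for (int d = 0; d < 2; ++d) {
  if (offsets[d] % shapePerCTATile[d] != 0 ||
      sizes[d] % shapePerCTATile[d] != 0)
    return emitError() << "offsets and sizes must be multiples of the "
                       << "shape per CTA tile in dimension " << d;
}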
@antiagainst the additional constraint I've identified to add to the verifier is making sure the divisors are not zero. I think moving the assert (line 107) would create duplicated code in the verifier, but I can move it there if that's the preference. I'm stuck on how to check in the verifier that the elements each thread is holding stay the same. Could you point me to how this could be done? And also let me know if there are any other checks I might have missed.
Sorry for the late reply; I got distracted previously. After addressing #4804 (comment), it looks like the current logic should be enough to guarantee that after slicing, threads are handling slices of the original elements without exchange/duplication. So this reads fine to me now @ThomasRaoux. Let me know if you still think some parts are missing.
auto offsets = op.getStaticOffsets();

// Calculate offsets and sizes in terms of CTA units.
std::vector<int64_t> CTAOffsets{offsets[0] / shapePerCTA[0],
Prefer not to mix std::vector with SmallVector. Just using std::array<int64_t, 2> works here.
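For illustration, a minimal sketch of that suggestion applied to the snippet above (assuming the element type stays int64_t):

// Fixed-size stack arrays instead of heap-allocating std::vector;
// the rank is known to be 2 from the verifier.
std::array<int64_t, 2> CTAOffsets{offsets[0] / shapePerCTA[0],
                                  offsets[1] / shapePerCTA[1]};
std::array<int64_t, 2> CTASizes{sizes[0] / shapePerCTA[0],
                                sizes[1] / shapePerCTA[1]};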
matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto srcTy = op.getSource().getType();
  if (isa<BlockedEncodingAttr>(op.getSource().getType().getEncoding()) ||
isa can support multiple attributes in <...>.
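A sketch of that suggestion; the second encoding attribute here is hypothetical, just to show the form:

// A single isa<> with several candidate types replaces a chain of ||:
if (isa<BlockedEncodingAttr, AMDMfmaEncodingAttr>(srcTy.getEncoding())) {
  // ... handle the supported encodings ...
}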
So this reads fine to me now @ThomasRaoux. Let me know if you still think some parts are missing.
Yes, I think this is fine. I got thrown off by some of the naming in the verifier, but I think it is correct.
I approved, but please address the remaining points before merging.
Force-pushed from fe79ae9 to f7d04fa.
}

auto srcShape = srcTy.getShape();
auto shapePerCTA = mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
I thought I commented on this but I don't see it anymore; sorry if it is a duplicate. Can we find a better name for this variable? In the rest of the code, shapePerCTA has a very different meaning: shapePerCTA means the sub-tensor owned by a CTA, whereas here you are only getting one tile of the shape per CTA.
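A sketch of the kind of rename being asked for; the new variable name is only an illustration:

// "Tile" in the name distinguishes one CTA tile from the full sub-tensor
// owned by a CTA, which is what shapePerCTA means elsewhere in the code.
auto shapePerCTATile =
    mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);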
Force-pushed from f7d04fa to dda5813.
Sorry, I pushed a version that was not well merged. I'll correct this.
Introduces a new operation for AMD GPUs to slice a tensor in memory:
- Adds new TritonAMDGPUDialect operation ViewSliceOp
- Adds a verifier for ViewSliceOp
- Adds conversion of the operation to LLVM
Force-pushed from 560b261 to 36c425f.
This commit introduces an extract_slice operation for the AMD backend
to enable viewing a slice of a tensor in registers without data exchange.
It enables breaking down large tiles of tensors into smaller ones
for better instruction interleaving and scheduling.
This can be useful for hiding global memory latency when a global
load/store can be efficiently split into several loads/stores to be
overlapped with compute, for example in attention.
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run pre-commit run --from-ref origin/main --to-ref HEAD.
- Select one of the following:
  - I have added tests (/test for lit tests, /unittest for C++ tests, /python/test for end-to-end tests).
  - This PR does not need a test because FILL THIS IN.
- Select one of the following:
  - I have not added any lit tests.
  - The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)